diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 880c8b24801..8351904b619 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -28,7 +28,7 @@ model. These are the: Hugging Face, as well as discriminating among model versions in Civitai, but can be used for arbitrary content. - * _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**) + * _ModelLoadServiceBase_ Responsible for loading a model from disk into RAM and VRAM and getting it ready for inference. @@ -41,10 +41,10 @@ The four main services can be found in * `invokeai/app/services/model_records/` * `invokeai/app/services/model_install/` * `invokeai/app/services/downloads/` -* `invokeai/app/services/model_loader/` (**under development**) +* `invokeai/app/services/model_load/` Code related to the FastAPI web API can be found in -`invokeai/app/api/routers/model_records.py`. +`invokeai/app/api/routers/model_manager_v2.py`. *** @@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but `ModelType`, `ModelFormat` and `BaseModelType` are string enums that are defined in `invokeai.backend.model_manager.config`. They are also imported by, and can be reexported from, -`invokeai.app.services.model_record_service`: +`invokeai.app.services.model_manager.model_records`: ``` -from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType +from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType ``` The `path` field can be absolute or relative. If relative, it is taken @@ -123,7 +123,7 @@ taken to be the `models_dir` directory. `variant` is an enumerated string class with values `normal`, `inpaint` and `depth`. If needed, it can be imported if needed from -either `invokeai.app.services.model_record_service` or +either `invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### ONNXSD2Config @@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or | `upcast_attention` | bool | Model requires its attention module to be upcast | The `SchedulerPredictionType` enum can be imported from either -`invokeai.app.services.model_record_service` or +`invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### Other config classes @@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base models. This works OK for some models, such as the IP Adapter image encoders, but is an all-or-nothing proposition. -Another issue is that the config class hierarchy is paralleled to some -extent by a `ModelBase` class hierarchy defined in -`invokeai.backend.model_manager.models.base` and its subclasses. These -are classes representing the models after they are loaded into RAM and -include runtime information such as load status and bytes used. Some -of the fields, including `name`, `model_type` and `base_model`, are -shared between `ModelConfigBase` and `ModelBase`, and this is a -potential source of confusion. - ## Reading and Writing Model Configuration Records The `ModelRecordService` provides the ability to retrieve model @@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the `InvocationContext` object: ``` -store = context.services.model_record_store +store = context.services.model_manager.store ``` or from elsewhere in the code by accessing -`ApiDependencies.invoker.services.model_record_store`. +`ApiDependencies.invoker.services.model_manager.store`. ### Creating a `ModelRecordService` @@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a `ModelRecordServiceFile` object: ``` -from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile +from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile store = ModelRecordServiceSQL.from_connection(connection, lock) store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db') @@ -252,7 +243,7 @@ So a typical startup pattern would be: ``` import sqlite3 from invokeai.app.services.thread import lock -from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() @@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False) store = ModelRecordServiceBase.open(config, db_conn, lock) ``` -_A note on simultaneous access to `invokeai.db`_: The current InvokeAI -service architecture for the image and graph databases is careful to -use a shared sqlite3 connection and a thread lock to ensure that two -threads don't attempt to access the database simultaneously. However, -the default `sqlite3` library used by Python reports using -**Serialized** mode, which allows multiple threads to access the -database simultaneously using multiple database connections (see -https://www.sqlite.org/threadsafe.html and -https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore -it should be safe to allow the record service to open its own SQLite -database connection. Opening a model record service should then be as -simple as `ModelRecordServiceBase.open(config)`. - ### Fetching a Model's Configuration from `ModelRecordServiceBase` Configurations can be retrieved in several ways. @@ -468,6 +446,44 @@ required parameters: Once initialized, the installer will provide the following methods: +#### install_job = installer.heuristic_import(source, [config], [access_token]) + +This is a simplified interface to the installer which takes a source +string, an optional model configuration dictionary and an optional +access token. + +The `source` is a string that can be any of these forms + +1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`) +2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`) +3. A HuggingFace repo_id with any of the following formats: + - `model/name` -- entire model + - `model/name:fp32` -- entire model, using the fp32 variant + - `model/name:fp16:vae` -- vae submodel, using the fp16 variant + - `model/name::vae` -- vae submodel, using default precision + - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant + - `model/name::path/to/model.safetensors` -- an individual model file, default variant + +Note that by specifying a relative path to the top of the HuggingFace +repo, you can download and install arbitrary models files. + +The variant, if not provided, will be automatically filled in with +`fp32` if the user has requested full precision, and `fp16` +otherwise. If a variant that does not exist is requested, then the +method will install whatever HuggingFace returns as its default +revision. + +`config` is an optional dict of values that will override the +autoprobed values for model type, base, scheduler prediction type, and +so forth. See [Model configuration and +probing](#Model-configuration-and-probing) for details. + +`access_token` is an optional access token for accessing resources +that need authentication. + +The method will return a `ModelInstallJob`. This object is discussed +at length in the following section. + #### install_job = installer.import_model() The `import_model()` method is the core of the installer. The @@ -486,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model +source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file -source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL -source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token +source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL +source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token for source in [source1, source2, source3, source4, source5, source6, source7]: install_job = installer.install_model(source) @@ -544,7 +561,6 @@ can be passed to `import_model()`. attributes returned by the model prober. See the section below for details. - #### LocalModelSource This is used for a model that is located on a locally-accessible Posix @@ -737,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return True if the job is in the complete, errored or cancelled states. -#### Model confguration and probing +#### Model configuration and probing The install service uses the `invokeai.backend.model_manager.probe` module during import to determine the model's type, base type, and @@ -776,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will return from the call if jobs aren't completed in the specified time. An argument of 0 (the default) will block indefinitely. +#### jobs = installer.wait_for_job(job, [timeout]) + +Like `wait_for_installs()`, but block until a specific job has +completed or errored, and then return the job. The optional `timeout` +argument will return from the call if the job doesn't complete in the +specified time. An argument of 0 (the default) will block +indefinitely. + #### jobs = installer.list_jobs() Return a list of all active and complete `ModelInstallJobs`. @@ -838,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally deletes the corresponding model weights file(s), regardless of whether they are inside or outside the InvokeAI models hierarchy. + +#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) + +This utility routine will download the model file located at source, +cache it, and return the path to the cached file. It does not attempt +to determine the model type, probe its configuration values, or +register it with the models database. + +You may provide an access token if the remote source requires +authorization. The call will block indefinitely until the file is +completely downloaded, cancelled or raises an error of some sort. If +you provide a timeout (in seconds), the call will raise a +`TimeoutError` exception if the download hasn't completed in the +specified period. + +You may use this mechanism to request any type of file, not just a +model. The file will be stored in a subdirectory of +`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the +cache, its path will be returned without redownloading it. + +Be aware that the models cache is cleared of infrequently-used files +and directories at regular intervals when the size of the cache +exceeds the value specified in Invoke's `convert_cache` configuration +variable. + #### List[str]=installer.scan_directory(scan_dir: Path, install: bool) This method will recursively scan the directory indicated in @@ -1128,7 +1177,7 @@ job = queue.create_download_job( event_handlers=[my_handler1, my_handler2], # if desired start=True, ) - ``` +``` The `filename` argument forces the downloader to use the specified name for the file rather than the name provided by the remote source, @@ -1171,6 +1220,13 @@ queue or was not created by this queue. This method will block until all the active jobs in the queue have reached a terminal state (completed, errored or cancelled). +#### queue.wait_for_job(job, [timeout]) + +This method will block until the indicated job has reached a terminal +state (completed, errored or cancelled). If the optional timeout is +provided, the call will block for at most timeout seconds, and raise a +TimeoutError otherwise. + #### jobs = queue.list_jobs() This will return a list of all jobs, including ones that have not yet @@ -1449,9 +1505,9 @@ set of keys to the corresponding model config objects. Find all model metadata records that have the given author and return a set of keys to the corresponding model config objects. -# The remainder of this documentation is provisional, pending implementation of the Load service +*** -## Let's get loaded, the lowdown on ModelLoadService +## The Lowdown on the ModelLoadService The `ModelLoadService` is responsible for loading a named model into memory so that it can be used for inference. Despite the fact that it @@ -1465,7 +1521,7 @@ create alternative instances if you wish. ### Creating a ModelLoadService object The class is defined in -`invokeai.app.services.model_loader_service`. It is initialized with +`invokeai.app.services.model_load`. It is initialized with an InvokeAIAppConfig object, from which it gets configuration information such as the user's desired GPU and precision, and with a previously-created `ModelRecordServiceBase` object, from which it @@ -1475,26 +1531,29 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_record_service import ModelRecordServiceBase -from invokeai.app.services.model_loader_service import ModelLoadService +from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry config = InvokeAIAppConfig.get_config() -store = ModelRecordServiceBase.open(config) -loader = ModelLoadService(config, store) +ram_cache = ModelCache( + max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger +) +convert_cache = ModelConvertCache( + cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size +) +loader = ModelLoadService( + app_config=config, + ram_cache=ram_cache, + convert_cache=convert_cache, + registry=ModelLoaderRegistry +) ``` -Note that we are relying on the contents of the application -configuration to choose the implementation of -`ModelRecordServiceBase`. - -### get_model(key, [submodel_type], [context]) -> ModelInfo: - -*** TO DO: change to get_model(key, context=None, **kwargs) +### load_model(model_config, [submodel_type], [context]) -> LoadedModel -The `get_model()` method, like its similarly-named cousin in -`ModelRecordService`, receives the unique key that identifies the +The `load_model()` method takes an `AnyModelConfig` returned by +`ModelRecordService.get_model()` and returns the corresponding loaded model. It loads the model into memory, gets the model ready for use, -and returns a `ModelInfo` object. +and returns a `LoadedModel` object. The optional second argument, `subtype` is a `SubModelType` string enum, such as "vae". It is mandatory when used with a main model, and @@ -1504,46 +1563,45 @@ The optional third argument, `context` can be provided by an invocation to trigger model load event reporting. See below for details. -The returned `ModelInfo` object shares some fields in common with -`ModelConfigBase`, but is otherwise a completely different beast: +The returned `LoadedModel` object contains a copy of the configuration +record returned by the model record `get_model()` method, as well as +the in-memory loaded model: -| **Field Name** | **Type** | **Description** | + +| **Attribute Name** | **Type** | **Description** | |----------------|-----------------|------------------| -| `key` | str | The model key derived from the ModelRecordService database | -| `name` | str | Name of this model | -| `base_model` | BaseModelType | Base model for this model | -| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)| -| `location` | Path or str | Location of the model on the filesystem | -| `precision` | torch.dtype | The torch.precision to use for inference | -| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use | +| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. | +| `model` | AnyModel | The instantiated model (details below) | +| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM | -The types for `ModelInfo` and `SubModelType` can be imported from -`invokeai.app.services.model_loader_service`. +Because the loader can return multiple model types, it is typed to +return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`, +`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and +`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers +models, `EmbeddingModelRaw` is used for LoRA and TextualInversion +models. The others are obvious. -To use the model, you use the `ModelInfo` as a context manager using -the following pattern: + +`LoadedModel` acts as a context manager. The context loads the model +into the execution device (e.g. VRAM on CUDA systems), locks the model +in the execution device for the duration of the context, and returns +the model. Use it like this: ``` -model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) +model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) with model_info as vae: image = vae.decode(latents)[0] ``` -The `vae` model will stay locked in the GPU during the period of time -it is in the context manager's scope. - -`get_model()` may raise any of the following exceptions: +`get_model_by_key()` may raise any of the following exceptions: -- `UnknownModelException` -- key not in database -- `ModelNotFoundException` -- key in database but model not found at path -- `InvalidModelException` -- the model is guilty of a variety of sins +- `UnknownModelException` -- key not in database +- `ModelNotFoundException` -- key in database but model not found at path +- `NotImplementedException` -- the loader doesn't know how to load this type of model -** TO DO: ** Resolve discrepancy between ModelInfo.location and -ModelConfig.path. - ### Emitting model loading events -When the `context` argument is passed to `get_model()`, it will +When the `context` argument is passed to `load_model_*()`, it will retrieve the invocation event bus from the passed `InvocationContext` object to emit events on the invocation bus. The two events are "model_load_started" and "model_load_completed". Both carry the @@ -1556,10 +1614,174 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), ) ``` +### Adding Model Loaders + +Model loaders are small classes that inherit from the `ModelLoader` +base class. They typically implement one method `_load_model()` whose +signature is: + +``` +def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, +) -> AnyModel: +``` + +`_load_model()` will be passed the path to the model on disk, an +optional repository variant (used by the diffusers loaders to select, +e.g. the `fp16` variant, and an optional submodel_type for main and +onnx models. + +To install a new loader, place it in +`invokeai/backend/model_manager/load/model_loaders`. Inherit from +`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to +indicate what type of models the loader can handle. + +Here is a complete example from `generic_diffusers.py`, which is able +to load several different diffusers types: + +``` +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from .. import ModelLoader, ModelLoaderRegistry + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result +``` + +Note that a loader can register itself to handle several different +model types. An exception will be raised if more than one loader tries +to register the same model type. + +#### Conversion + +Some models require conversion to diffusers format before they can be +loaded. These loaders should override two additional methods: + +``` +_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool +_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: +``` + +The first method accepts the model configuration, the path to where +the unmodified model is currently installed, and a proposed +destination for the converted model. This method returns True if the +model needs to be converted. It typically does this by comparing the +last modification time of the original model file to the modification +time of the converted model. In some cases you will also want to check +the modification date of the configuration record, in the event that +the user has changed something like the scheduler prediction type that +will require the model to be re-converted. See `controlnet.py` for an +example of this logic. + +The second method accepts the model configuration, the path to the +original model on disk, and the desired output path for the converted +model. It does whatever it needs to do to get the model into diffusers +format, and returns the Path of the resulting model. (The path should +ordinarily be the same as `output_path`.) + +## The ModelManagerService object + +For convenience, the API provides a `ModelManagerService` object which +gives a single point of access to the major model manager +services. This object is created at initialization time and can be +found in the global `ApiDependencies.invoker.services.model_manager` +object, or in `context.services.model_manager` from within an +invocation. + +In the examples below, we have retrieved the manager using: +``` +mm = ApiDependencies.invoker.services.model_manager +``` + +The following properties and methods will be available: + +### mm.store + +This retrieves the `ModelRecordService` associated with the +manager. Example: + +``` +configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5') +``` + +### mm.install + +This retrieves the `ModelInstallService` associated with the manager. +Example: + +``` +job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`) +``` + +### mm.load + +This retrieves the `ModelLoaderService` associated with the manager. Example: + +``` +configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5') +assert len(configs) > 0 + +loaded_model = mm.load.load_model(configs[0]) +``` + +The model manager also offers a few convenience shortcuts for loading +models: + +### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel + +Same as `mm.load.load_model()`. + +### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel + +This accepts the combination of the model's name, type and base, which +it passes to the model record config store for retrieval. If a unique +model config is found, this method returns a `LoadedModel`. It can +raise the following exceptions: + +``` +UnknownModelException -- model with these attributes not known +NotImplementedException -- the loader doesn't know how to load this type of model +ValueError -- more than one model matches this combination of base/type/name +``` + +### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel + +This method takes a model key, looks it up using the +`ModelRecordServiceBase` object in `mm.store`, and passes the returned +model configuration to `load_model_by_config()`. It may raise a +`NotImplementedException`. diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c8309e1729e..264f42a9bde 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,7 +4,6 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from invokeai.app.services.shared.sqlite.sqlite_util import init_db -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -25,8 +24,8 @@ from ..services.invoker import Invoker from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage -from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService +from ..services.model_metadata import ModelMetadataStoreSQL from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor @@ -85,16 +84,13 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) - model_manager = ModelManagerService(config, logger) - model_record_service = ModelRecordServiceSQL(db=db) download_queue_service = DownloadQueueService(event_bus=events) - metadata_store = ModelMetadataStore(db=db) - model_install_service = ModelInstallService( - app_config=config, - record_store=model_record_service, + model_metadata_service = ModelMetadataStoreSQL(db=db) + model_manager = ModelManagerService.build_model_manager( + app_config=configuration, + model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service), download_queue=download_queue_service, - metadata_store=metadata_store, - event_bus=events, + events=events, ) names = SimpleNameService() performance_statistics = InvocationStatsService() @@ -120,9 +116,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger latents=latents, logger=logger, model_manager=model_manager, - model_records=model_record_service, download_queue=download_queue_service, - model_install=model_install_service, names=names, performance_statistics=performance_statistics, processor=processor, diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py index 92b658c3708..a6e53c7a5c4 100644 --- a/invokeai/app/api/routers/download_queue.py +++ b/invokeai/app/api/routers/download_queue.py @@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]: 400: {"description": "Bad request"}, }, ) -async def prune_downloads(): +async def prune_downloads() -> Response: """Prune completed and errored jobs.""" queue = ApiDependencies.invoker.services.download_queue queue.prune_jobs() @@ -55,7 +55,7 @@ async def download( ) -> DownloadJob: """Download the source URL to the file or directory indicted in dest.""" queue = ApiDependencies.invoker.services.download_queue - return queue.download(source, dest, priority, access_token) + return queue.download(source, Path(dest), priority, access_token) @download_queue_router.get( @@ -87,7 +87,7 @@ async def get_download_job( ) async def cancel_download_job( id: int = Path(description="ID of the download job to cancel."), -): +) -> Response: """Cancel a download job using its ID.""" try: queue = ApiDependencies.invoker.services.download_queue @@ -105,7 +105,7 @@ async def cancel_download_job( 204: {"description": "Download jobs have been cancelled"}, }, ) -async def cancel_all_download_jobs(): +async def cancel_all_download_jobs() -> Response: """Cancel all download jobs.""" ApiDependencies.invoker.services.download_queue.cancel_all_jobs() return Response(status_code=204) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py new file mode 100644 index 00000000000..6b7111dd2ce --- /dev/null +++ b/invokeai/app/api/routers/model_manager.py @@ -0,0 +1,759 @@ +# Copyright (c) 2023 Lincoln D. Stein +"""FastAPI route for model configuration records.""" + +import pathlib +import shutil +from hashlib import sha1 +from random import randbytes +from typing import Any, Dict, List, Optional, Set + +from fastapi import Body, Path, Query, Response +from fastapi.routing import APIRouter +from pydantic import BaseModel, ConfigDict +from starlette.exceptions import HTTPException +from typing_extensions import Annotated + +from invokeai.app.services.model_install import ModelInstallJob, ModelSource +from invokeai.app.services.model_records import ( + DuplicateModelException, + InvalidModelException, + ModelRecordOrderBy, + ModelSummary, + UnknownModelException, +) +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.backend.model_manager.config import ( + AnyModelConfig, + BaseModelType, + MainCheckpointConfig, + ModelFormat, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..dependencies import ApiDependencies + +model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) + + +class ModelsList(BaseModel): + """Return list of configs.""" + + models: List[AnyModelConfig] + + model_config = ConfigDict(use_enum_values=True) + + +class ModelTagSet(BaseModel): + """Return tags for a set of models.""" + + key: str + name: str + author: str + tags: Set[str] + + +############################################################################## +# These are example inputs and outputs that are used in places where Swagger +# is unable to generate a correct example. +############################################################################## +example_model_config = { + "path": "string", + "name": "string", + "base": "sd-1", + "type": "main", + "format": "checkpoint", + "config": "string", + "key": "string", + "original_hash": "string", + "current_hash": "string", + "description": "string", + "source": "string", + "last_modified": 0, + "vae": "string", + "variant": "normal", + "prediction_type": "epsilon", + "repo_variant": "fp16", + "upcast_attention": False, + "ztsnr_training": False, +} + +example_model_input = { + "path": "/path/to/model", + "name": "model_name", + "base": "sd-1", + "type": "main", + "format": "checkpoint", + "config": "configs/stable-diffusion/v1-inference.yaml", + "description": "Model description", + "vae": None, + "variant": "normal", +} + +example_model_metadata = { + "name": "ip_adapter_sd_image_encoder", + "author": "InvokeAI", + "tags": [ + "transformers", + "safetensors", + "clip_vision_model", + "endpoints_compatible", + "region:us", + "has_space", + "license:apache-2.0", + ], + "files": [ + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md", + "path": "ip_adapter_sd_image_encoder/README.md", + "size": 628, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json", + "path": "ip_adapter_sd_image_encoder/config.json", + "size": 560, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors", + "path": "ip_adapter_sd_image_encoder/model.safetensors", + "size": 2528373448, + "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030", + }, + ], + "type": "huggingface", + "id": "InvokeAI/ip_adapter_sd_image_encoder", + "tag_dict": {"license": "apache-2.0"}, + "last_modified": "2023-09-23T17:33:25Z", +} + +############################################################################## +# ROUTES +############################################################################## + + +@model_manager_router.get( + "/", + operation_id="list_model_records", +) +async def list_model_records( + base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), + model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), + model_format: Optional[ModelFormat] = Query( + default=None, description="Exact match on the format of the model (e.g. 'diffusers')" + ), +) -> ModelsList: + """Get a list of models.""" + record_store = ApiDependencies.invoker.services.model_manager.store + found_models: list[AnyModelConfig] = [] + if base_models: + for base_model in base_models: + found_models.extend( + record_store.search_by_attr( + base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + ) + ) + else: + found_models.extend( + record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + ) + return ModelsList(models=found_models) + + +@model_manager_router.get( + "/i/{key}", + operation_id="get_model_record", + responses={ + 200: { + "description": "The model configuration was retrieved successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "The model could not be found"}, + }, +) +async def get_model_record( + key: str = Path(description="Key of the model record to fetch."), +) -> AnyModelConfig: + """Get a model record""" + record_store = ApiDependencies.invoker.services.model_manager.store + try: + config: AnyModelConfig = record_store.get_model(key) + return config + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@model_manager_router.get("/summary", operation_id="list_model_summary") +async def list_model_summary( + page: int = Query(default=0, description="The page to get"), + per_page: int = Query(default=10, description="The number of models per page"), + order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), +) -> PaginatedResults[ModelSummary]: + """Gets a page of model summary data.""" + record_store = ApiDependencies.invoker.services.model_manager.store + results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) + return results + + +@model_manager_router.get( + "/meta/i/{key}", + operation_id="get_model_metadata", + responses={ + 200: { + "description": "The model metadata was retrieved successfully", + "content": {"application/json": {"example": example_model_metadata}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "No metadata available"}, + }, +) +async def get_model_metadata( + key: str = Path(description="Key of the model repo metadata to fetch."), +) -> Optional[AnyModelRepoMetadata]: + """Get a model metadata object.""" + record_store = ApiDependencies.invoker.services.model_manager.store + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) + if not result: + raise HTTPException(status_code=404, detail="No metadata for a model with this key") + return result + + +@model_manager_router.get( + "/tags", + operation_id="list_tags", +) +async def list_tags() -> Set[str]: + """Get a unique set of all the model tags.""" + record_store = ApiDependencies.invoker.services.model_manager.store + result: Set[str] = record_store.list_tags() + return result + + +@model_manager_router.get( + "/tags/search", + operation_id="search_by_metadata_tags", +) +async def search_by_metadata_tags( + tags: Set[str] = Query(default=None, description="Tags to search for"), +) -> ModelsList: + """Get a list of models.""" + record_store = ApiDependencies.invoker.services.model_manager.store + results = record_store.search_by_metadata_tag(tags) + return ModelsList(models=results) + + +@model_manager_router.patch( + "/i/{key}", + operation_id="update_model_record", + responses={ + 200: { + "description": "The model was updated successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "The model could not be found"}, + 409: {"description": "There is already a model corresponding to the new name"}, + }, + status_code=200, +) +async def update_model_record( + key: Annotated[str, Path(description="Unique key of model")], + info: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> AnyModelConfig: + """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" + logger = ApiDependencies.invoker.services.logger + record_store = ApiDependencies.invoker.services.model_manager.store + try: + model_response: AnyModelConfig = record_store.update_model(key, config=info) + logger.info(f"Updated model: {key}") + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return model_response + + +@model_manager_router.delete( + "/i/{key}", + operation_id="del_model_record", + responses={ + 204: {"description": "Model deleted successfully"}, + 404: {"description": "Model not found"}, + }, + status_code=204, +) +async def del_model_record( + key: str = Path(description="Unique key of model to remove from model registry."), +) -> Response: + """ + Delete model record from database. + + The configuration record will be removed. The corresponding weights files will be + deleted as well if they reside within the InvokeAI "models" directory. + """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + installer.delete(key) + logger.info(f"Deleted model: {key}") + return Response(status_code=204) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + + +@model_manager_router.post( + "/i/", + operation_id="add_model_record", + responses={ + 201: { + "description": "The model added successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + 415: {"description": "Unrecognized file/folder format"}, + }, + status_code=201, +) +async def add_model_record( + config: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> AnyModelConfig: + """Add a model using the configuration information appropriate for its type.""" + logger = ApiDependencies.invoker.services.logger + record_store = ApiDependencies.invoker.services.model_manager.store + if config.key == "": + config.key = sha1(randbytes(100)).hexdigest() + logger.info(f"Created model {config.key} for {config.name}") + try: + record_store.add_model(config.key, config) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + + # now fetch it out + result: AnyModelConfig = record_store.get_model(config.key) + return result + + +@model_manager_router.post( + "/heuristic_import", + operation_id="heuristic_import_model", + responses={ + 201: {"description": "The model imported successfully"}, + 415: {"description": "Unrecognized file/folder format"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, +) +async def heuristic_import( + source: str, + config: Optional[Dict[str, Any]] = Body( + description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", + default=None, + example={"name": "modelT", "description": "antique cars"}, + ), + access_token: Optional[str] = None, +) -> ModelInstallJob: + """Install a model using a string identifier. + + `source` can be any of the following. + + 1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors') + 2. A Url pointing to a single downloadable model file + 3. A HuggingFace repo_id with any of the following formats: + - model/name + - model/name:fp16:vae + - model/name::vae -- use default precision + - model/name:fp16:path/to/model.safetensors + - model/name::path/to/model.safetensors + + `config` is an optional dict containing model configuration values that will override + the ones that are probed automatically. + + `access_token` is an optional access token for use with Urls that require + authentication. + + Models will be downloaded, probed, configured and installed in a + series of background threads. The return object has `status` attribute + that can be used to monitor progress. + + See the documentation for `import_model_record` for more information on + interpreting the job information returned by this route. + """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + result: ModelInstallJob = installer.heuristic_import( + source=source, + config=config, + ) + logger.info(f"Started installation of {source}") + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return result + + +@model_manager_router.post( + "/install", + operation_id="import_model", + responses={ + 201: {"description": "The model imported successfully"}, + 415: {"description": "Unrecognized file/folder format"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, +) +async def import_model( + source: ModelSource, + config: Optional[Dict[str, Any]] = Body( + description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", + default=None, + ), +) -> ModelInstallJob: + """Install a model using its local path, repo_id, or remote URL. + + Models will be downloaded, probed, configured and installed in a + series of background threads. The return object has `status` attribute + that can be used to monitor progress. + + The source object is a discriminated Union of LocalModelSource, + HFModelSource and URLModelSource. Set the "type" field to the + appropriate value: + + * To install a local path using LocalModelSource, pass a source of form: + ``` + { + "type": "local", + "path": "/path/to/model", + "inplace": false + } + ``` + The "inplace" flag, if true, will register the model in place in its + current filesystem location. Otherwise, the model will be copied + into the InvokeAI models directory. + + * To install a HuggingFace repo_id using HFModelSource, pass a source of form: + ``` + { + "type": "hf", + "repo_id": "stabilityai/stable-diffusion-2.0", + "variant": "fp16", + "subfolder": "vae", + "access_token": "f5820a918aaf01" + } + ``` + The `variant`, `subfolder` and `access_token` fields are optional. + + * To install a remote model using an arbitrary URL, pass: + ``` + { + "type": "url", + "url": "http://www.civitai.com/models/123456", + "access_token": "f5820a918aaf01" + } + ``` + The `access_token` field is optonal + + The model's configuration record will be probed and filled in + automatically. To override the default guesses, pass "metadata" + with a Dict containing the attributes you wish to override. + + Installation occurs in the background. Either use list_model_install_jobs() + to poll for completion, or listen on the event bus for the following events: + + * "model_install_running" + * "model_install_completed" + * "model_install_error" + + On successful completion, the event's payload will contain the field "key" + containing the installed ID of the model. On an error, the event's payload + will contain the fields "error_type" and "error" describing the nature of the + error and its traceback, respectively. + + """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + result: ModelInstallJob = installer.import_model( + source=source, + config=config, + ) + logger.info(f"Started installation of {source}") + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return result + + +@model_manager_router.get( + "/import", + operation_id="list_model_install_jobs", +) +async def list_model_install_jobs() -> List[ModelInstallJob]: + """Return the list of model install jobs. + + Install jobs have a numeric `id`, a `status`, and other fields that provide information on + the nature of the job and its progress. The `status` is one of: + + * "waiting" -- Job is waiting in the queue to run + * "downloading" -- Model file(s) are downloading + * "running" -- Model has downloaded and the model probing and registration process is running + * "completed" -- Installation completed successfully + * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * "cancelled" -- Job was cancelled before completion. + + Once completed, information about the model such as its size, base + model, type, and metadata can be retrieved from the `config_out` + field. For multi-file models such as diffusers, information on individual files + can be retrieved from `download_parts`. + + See the example and schema below for more information. + """ + jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() + return jobs + + +@model_manager_router.get( + "/import/{id}", + operation_id="get_model_install_job", + responses={ + 200: {"description": "Success"}, + 404: {"description": "No such job"}, + }, +) +async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: + """ + Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' + for information on the format of the return value. + """ + try: + result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) + return result + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@model_manager_router.delete( + "/import/{id}", + operation_id="cancel_model_install_job", + responses={ + 201: {"description": "The job was cancelled successfully"}, + 415: {"description": "No such job"}, + }, + status_code=201, +) +async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: + """Cancel the model install job(s) corresponding to the given job ID.""" + installer = ApiDependencies.invoker.services.model_manager.install + try: + job = installer.get_job_by_id(id) + except ValueError as e: + raise HTTPException(status_code=415, detail=str(e)) + installer.cancel_job(job) + + +@model_manager_router.patch( + "/import", + operation_id="prune_model_install_jobs", + responses={ + 204: {"description": "All completed and errored jobs have been pruned"}, + 400: {"description": "Bad request"}, + }, +) +async def prune_model_install_jobs() -> Response: + """Prune all completed and errored jobs from the install job list.""" + ApiDependencies.invoker.services.model_manager.install.prune_jobs() + return Response(status_code=204) + + +@model_manager_router.patch( + "/sync", + operation_id="sync_models_to_config", + responses={ + 204: {"description": "Model config record database resynced with files on disk"}, + 400: {"description": "Bad request"}, + }, +) +async def sync_models_to_config() -> Response: + """ + Traverse the models and autoimport directories. + + Model files without a corresponding + record in the database are added. Orphan records without a models file are deleted. + """ + ApiDependencies.invoker.services.model_manager.install.sync_to_config() + return Response(status_code=204) + + +@model_manager_router.put( + "/convert/{key}", + operation_id="convert_model", + responses={ + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, +) +async def convert_model( + key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), +) -> AnyModelConfig: + """ + Permanently convert a model into diffusers format, replacing the safetensors version. + Note that during the conversion process the key and model hash will change. + The return value is the model configuration for the converted model. + """ + logger = ApiDependencies.invoker.services.logger + loader = ApiDependencies.invoker.services.model_manager.load + store = ApiDependencies.invoker.services.model_manager.store + installer = ApiDependencies.invoker.services.model_manager.install + + try: + model_config = store.get_model(key) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + + if not isinstance(model_config, MainCheckpointConfig): + logger.error(f"The model with key {key} is not a main checkpoint model.") + raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") + + # loading the model will convert it into a cached diffusers file + loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) + + # Get the path of the converted model from the loader + cache_path = loader.convert_cache.cache_path(key) + assert cache_path.exists() + + # temporarily rename the original safetensors file so that there is no naming conflict + original_name = model_config.name + model_config.name = f"{original_name}.DELETE" + store.update_model(key, config=model_config) + + # install the diffusers + try: + new_key = installer.install_path( + cache_path, + config={ + "name": original_name, + "description": model_config.description, + "original_hash": model_config.original_hash, + "source": model_config.source, + }, + ) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + + # get the original metadata + if orig_metadata := store.get_metadata(key): + store.metadata_store.add_metadata(new_key, orig_metadata) + + # delete the original safetensors file + installer.delete(key) + + # delete the cached version + shutil.rmtree(cache_path) + + # return the config record for the new diffusers directory + new_config: AnyModelConfig = store.get_model(new_key) + return new_config + + +@model_manager_router.put( + "/merge", + operation_id="merge", + responses={ + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, +) +async def merge( + keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), + merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), + alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), + force: bool = Body( + description="Force merging of models created with different versions of diffusers", + default=False, + ), + interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), + merge_dest_directory: Optional[str] = Body( + description="Save the merged model to the designated directory (with 'merged_model_name' appended)", + default=None, + ), +) -> AnyModelConfig: + """ + Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. + ``` + Argument Description [default] + -------- ---------------------- + keys List of 2-3 model keys to merge together. All models must use the same base type. + merged_model_name Name for the merged model [Concat model names] + alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] + force If true, force the merge even if the models were generated by different versions of the diffusers library [False] + interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + merge_dest_directory Specify a directory to store the merged model in [models directory] + ``` + """ + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") + dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None + installer = ApiDependencies.invoker.services.model_manager.install + merger = ModelMerger(installer) + model_names = [installer.record_store.get_model(x).name for x in keys] + response = merger.merge_diffusion_models_and_save( + model_keys=keys, + merged_model_name=merged_model_name or "+".join(model_names), + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory=dest, + ) + except UnknownModelException: + raise HTTPException( + status_code=404, + detail=f"One or more of the models '{keys}' not found", + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return response diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py deleted file mode 100644 index f9a3e408985..00000000000 --- a/invokeai/app/api/routers/model_records.py +++ /dev/null @@ -1,472 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein -"""FastAPI route for model configuration records.""" - -import pathlib -from hashlib import sha1 -from random import randbytes -from typing import Any, Dict, List, Optional, Set - -from fastapi import Body, Path, Query, Response -from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict -from starlette.exceptions import HTTPException -from typing_extensions import Annotated - -from invokeai.app.services.model_install import ModelInstallJob, ModelSource -from invokeai.app.services.model_records import ( - DuplicateModelException, - InvalidModelException, - ModelRecordOrderBy, - ModelSummary, - UnknownModelException, -) -from invokeai.app.services.shared.pagination import PaginatedResults -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - BaseModelType, - ModelFormat, - ModelType, -) -from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - -from ..dependencies import ApiDependencies - -model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"]) - - -class ModelsList(BaseModel): - """Return list of configs.""" - - models: List[AnyModelConfig] - - model_config = ConfigDict(use_enum_values=True) - - -class ModelTagSet(BaseModel): - """Return tags for a set of models.""" - - key: str - name: str - author: str - tags: Set[str] - - -@model_records_router.get( - "/", - operation_id="list_model_records", -) -async def list_model_records( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), - model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), - model_format: Optional[ModelFormat] = Query( - default=None, description="Exact match on the format of the model (e.g. 'diffusers')" - ), -) -> ModelsList: - """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records - found_models: list[AnyModelConfig] = [] - if base_models: - for base_model in base_models: - found_models.extend( - record_store.search_by_attr( - base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format - ) - ) - else: - found_models.extend( - record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) - ) - return ModelsList(models=found_models) - - -@model_records_router.get( - "/i/{key}", - operation_id="get_model_record", - responses={ - 200: {"description": "Success"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - }, -) -async def get_model_record( - key: str = Path(description="Key of the model record to fetch."), -) -> AnyModelConfig: - """Get a model record""" - record_store = ApiDependencies.invoker.services.model_records - try: - return record_store.get_model(key) - except UnknownModelException as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@model_records_router.get("/meta", operation_id="list_model_summary") -async def list_model_summary( - page: int = Query(default=0, description="The page to get"), - per_page: int = Query(default=10, description="The number of models per page"), - order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), -) -> PaginatedResults[ModelSummary]: - """Gets a page of model summary data.""" - return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by) - - -@model_records_router.get( - "/meta/i/{key}", - operation_id="get_model_metadata", - responses={ - 200: {"description": "Success"}, - 400: {"description": "Bad request"}, - 404: {"description": "No metadata available"}, - }, -) -async def get_model_metadata( - key: str = Path(description="Key of the model repo metadata to fetch."), -) -> Optional[AnyModelRepoMetadata]: - """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_records - result = record_store.get_metadata(key) - if not result: - raise HTTPException(status_code=404, detail="No metadata for a model with this key") - return result - - -@model_records_router.get( - "/tags", - operation_id="list_tags", -) -async def list_tags() -> Set[str]: - """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_records - return record_store.list_tags() - - -@model_records_router.get( - "/tags/search", - operation_id="search_by_metadata_tags", -) -async def search_by_metadata_tags( - tags: Set[str] = Query(default=None, description="Tags to search for"), -) -> ModelsList: - """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records - results = record_store.search_by_metadata_tag(tags) - return ModelsList(models=results) - - -@model_records_router.patch( - "/i/{key}", - operation_id="update_model_record", - responses={ - 200: {"description": "The model was updated successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - 409: {"description": "There is already a model corresponding to the new name"}, - }, - status_code=200, - response_model=AnyModelConfig, -) -async def update_model_record( - key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: - """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" - logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records - try: - model_response = record_store.update_model(key, config=info) - logger.info(f"Updated model: {key}") - except UnknownModelException as e: - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - return model_response - - -@model_records_router.delete( - "/i/{key}", - operation_id="del_model_record", - responses={ - 204: {"description": "Model deleted successfully"}, - 404: {"description": "Model not found"}, - }, - status_code=204, -) -async def del_model_record( - key: str = Path(description="Unique key of model to remove from model registry."), -) -> Response: - """ - Delete model record from database. - - The configuration record will be removed. The corresponding weights files will be - deleted as well if they reside within the InvokeAI "models" directory. - """ - logger = ApiDependencies.invoker.services.logger - - try: - installer = ApiDependencies.invoker.services.model_install - installer.delete(key) - logger.info(f"Deleted model: {key}") - return Response(status_code=204) - except UnknownModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - - -@model_records_router.post( - "/i/", - operation_id="add_model_record", - responses={ - 201: {"description": "The model added successfully"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - 415: {"description": "Unrecognized file/folder format"}, - }, - status_code=201, -) -async def add_model_record( - config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: - """Add a model using the configuration information appropriate for its type.""" - logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records - if config.key == "": - config.key = sha1(randbytes(100)).hexdigest() - logger.info(f"Created model {config.key} for {config.name}") - try: - record_store.add_model(config.key, config) - except DuplicateModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - - # now fetch it out - return record_store.get_model(config.key) - - -@model_records_router.post( - "/import", - operation_id="import_model_record", - responses={ - 201: {"description": "The model imported successfully"}, - 415: {"description": "Unrecognized file/folder format"}, - 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, -) -async def import_model( - source: ModelSource, - config: Optional[Dict[str, Any]] = Body( - description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", - default=None, - ), -) -> ModelInstallJob: - """Add a model using its local path, repo_id, or remote URL. - - Models will be downloaded, probed, configured and installed in a - series of background threads. The return object has `status` attribute - that can be used to monitor progress. - - The source object is a discriminated Union of LocalModelSource, - HFModelSource and URLModelSource. Set the "type" field to the - appropriate value: - - * To install a local path using LocalModelSource, pass a source of form: - `{ - "type": "local", - "path": "/path/to/model", - "inplace": false - }` - The "inplace" flag, if true, will register the model in place in its - current filesystem location. Otherwise, the model will be copied - into the InvokeAI models directory. - - * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - `{ - "type": "hf", - "repo_id": "stabilityai/stable-diffusion-2.0", - "variant": "fp16", - "subfolder": "vae", - "access_token": "f5820a918aaf01" - }` - The `variant`, `subfolder` and `access_token` fields are optional. - - * To install a remote model using an arbitrary URL, pass: - `{ - "type": "url", - "url": "http://www.civitai.com/models/123456", - "access_token": "f5820a918aaf01" - }` - The `access_token` field is optonal - - The model's configuration record will be probed and filled in - automatically. To override the default guesses, pass "metadata" - with a Dict containing the attributes you wish to override. - - Installation occurs in the background. Either use list_model_install_jobs() - to poll for completion, or listen on the event bus for the following events: - - "model_install_running" - "model_install_completed" - "model_install_error" - - On successful completion, the event's payload will contain the field "key" - containing the installed ID of the model. On an error, the event's payload - will contain the fields "error_type" and "error" describing the nature of the - error and its traceback, respectively. - - """ - logger = ApiDependencies.invoker.services.logger - - try: - installer = ApiDependencies.invoker.services.model_install - result: ModelInstallJob = installer.import_model( - source=source, - config=config, - ) - logger.info(f"Started installation of {source}") - except UnknownModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=424, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - return result - - -@model_records_router.get( - "/import", - operation_id="list_model_install_jobs", -) -async def list_model_install_jobs() -> List[ModelInstallJob]: - """Return list of model install jobs.""" - jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs() - return jobs - - -@model_records_router.get( - "/import/{id}", - operation_id="get_model_install_job", - responses={ - 200: {"description": "Success"}, - 404: {"description": "No such job"}, - }, -) -async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: - """Return model install job corresponding to the given source.""" - try: - return ApiDependencies.invoker.services.model_install.get_job_by_id(id) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@model_records_router.delete( - "/import/{id}", - operation_id="cancel_model_install_job", - responses={ - 201: {"description": "The job was cancelled successfully"}, - 415: {"description": "No such job"}, - }, - status_code=201, -) -async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: - """Cancel the model install job(s) corresponding to the given job ID.""" - installer = ApiDependencies.invoker.services.model_install - try: - job = installer.get_job_by_id(id) - except ValueError as e: - raise HTTPException(status_code=415, detail=str(e)) - installer.cancel_job(job) - - -@model_records_router.patch( - "/import", - operation_id="prune_model_install_jobs", - responses={ - 204: {"description": "All completed and errored jobs have been pruned"}, - 400: {"description": "Bad request"}, - }, -) -async def prune_model_install_jobs() -> Response: - """Prune all completed and errored jobs from the install job list.""" - ApiDependencies.invoker.services.model_install.prune_jobs() - return Response(status_code=204) - - -@model_records_router.patch( - "/sync", - operation_id="sync_models_to_config", - responses={ - 204: {"description": "Model config record database resynced with files on disk"}, - 400: {"description": "Bad request"}, - }, -) -async def sync_models_to_config() -> Response: - """ - Traverse the models and autoimport directories. - - Model files without a corresponding - record in the database are added. Orphan records without a models file are deleted. - """ - ApiDependencies.invoker.services.model_install.sync_to_config() - return Response(status_code=204) - - -@model_records_router.put( - "/merge", - operation_id="merge", -) -async def merge( - keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), - merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), - alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), - force: bool = Body( - description="Force merging of models created with different versions of diffusers", - default=False, - ), - interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), - merge_dest_directory: Optional[str] = Body( - description="Save the merged model to the designated directory (with 'merged_model_name' appended)", - default=None, - ), -) -> AnyModelConfig: - """ - Merge diffusers models. - - keys: List of 2-3 model keys to merge together. All models must use the same base type. - merged_model_name: Name for the merged model [Concat model names] - alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] - force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] - interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] - merge_dest_directory: Specify a directory to store the merged model in [models directory] - """ - print(f"here i am, keys={keys}") - logger = ApiDependencies.invoker.services.logger - try: - logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") - dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - installer = ApiDependencies.invoker.services.model_install - merger = ModelMerger(installer) - model_names = [installer.record_store.get_model(x).name for x in keys] - response = merger.merge_diffusion_models_and_save( - model_keys=keys, - merged_model_name=merged_model_name or "+".join(model_names), - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=dest, - ) - except UnknownModelException: - raise HTTPException( - status_code=404, - detail=f"One or more of the models '{keys}' not found", - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py deleted file mode 100644 index 8f83820cf89..00000000000 --- a/invokeai/app/api/routers/models.py +++ /dev/null @@ -1,427 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein - -import pathlib -from typing import Annotated, List, Literal, Optional, Union - -from fastapi import Body, Path, Query, Response -from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter -from starlette.exceptions import HTTPException - -from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import MergeInterpolationMethod -from invokeai.backend.model_management.models import ( - OPENAPI_MODEL_CONFIGS, - InvalidModelException, - ModelNotFoundException, - SchedulerPredictionType, -) - -from ..dependencies import ApiDependencies - -models_router = APIRouter(prefix="/v1/models", tags=["models"]) - -UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse) - -ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelResponseValidator = TypeAdapter(ImportModelResponse) - -ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse) - -MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] - - -class ModelsList(BaseModel): - models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] - - model_config = ConfigDict(use_enum_values=True) - - -ModelsListValidator = TypeAdapter(ModelsList) - - -@models_router.get( - "/", - operation_id="list_models", - responses={200: {"model": ModelsList}}, -) -async def list_models( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), -) -> ModelsList: - """Gets a list of models""" - if base_models and len(base_models) > 0: - models_raw = [] - for base_model in base_models: - models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) - else: - models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) - models = ModelsListValidator.validate_python({"models": models_raw}) - return models - - -@models_router.patch( - "/{base_model}/{model_type}/{model_name}", - operation_id="update_model", - responses={ - 200: {"description": "The model was updated successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - 409: {"description": "There is already a model corresponding to the new name"}, - }, - status_code=200, - response_model=UpdateModelResponse, -) -async def update_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> UpdateModelResponse: - """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" - logger = ApiDependencies.invoker.services.logger - - try: - previous_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - - # rename operation requested - if info.model_name != model_name or info.base_model != base_model: - ApiDependencies.invoker.services.model_manager.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=info.model_name, - new_base=info.base_model, - ) - logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}") - # update information to support an update of attributes - model_name = info.model_name - base_model = info.base_model - new_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - if new_info.get("path") != previous_info.get( - "path" - ): # model manager moved model path during rename - don't overwrite it - info.path = new_info.get("path") - - # replace empty string values with None/null to avoid phenomenon of vae: '' - info_dict = info.model_dump() - info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} - - ApiDependencies.invoker.services.model_manager.update_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - model_attributes=info_dict, - ) - - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - model_response = UpdateModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) - - return model_response - - -@models_router.post( - "/import", - operation_id="import_model", - responses={ - 201: {"description": "The model imported successfully"}, - 404: {"description": "The model could not be found"}, - 415: {"description": "Unrecognized file/folder format"}, - 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, - response_model=ImportModelResponse, -) -async def import_model( - location: str = Body(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body( - description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints", - default=None, - ), -) -> ImportModelResponse: - """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" - - location = location.strip("\"' ") - items_to_import = {location} - prediction_types = {x.value: x for x in SchedulerPredictionType} - logger = ApiDependencies.invoker.services.logger - - try: - installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import=items_to_import, - prediction_type_helper=lambda x: prediction_types.get(prediction_type), - ) - info = installed_models.get(location) - - if not info: - logger.error("Import failed") - raise HTTPException(status_code=415) - - logger.info(f"Successfully imported {location}, got {info}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.name, base_model=info.base_model, model_type=info.model_type - ) - return ImportModelResponseValidator.validate_python(model_raw) - - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - - -@models_router.post( - "/add", - operation_id="add_model", - responses={ - 201: {"description": "The model added successfully"}, - 404: {"description": "The model could not be found"}, - 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, - response_model=ImportModelResponse, -) -async def add_model( - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> ImportModelResponse: - """Add a model using the configuration information appropriate for its type. Only local models can be added by path""" - - logger = ApiDependencies.invoker.services.logger - - try: - ApiDependencies.invoker.services.model_manager.add_model( - info.model_name, - info.base_model, - info.model_type, - model_attributes=info.model_dump(), - ) - logger.info(f"Successfully added {info.model_name}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.model_name, - base_model=info.base_model, - model_type=info.model_type, - ) - return ImportModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - - -@models_router.delete( - "/{base_model}/{model_type}/{model_name}", - operation_id="del_model", - responses={ - 204: {"description": "Model deleted successfully"}, - 404: {"description": "Model not found"}, - }, - status_code=204, - response_model=None, -) -async def delete_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), -) -> Response: - """Delete Model""" - logger = ApiDependencies.invoker.services.logger - - try: - ApiDependencies.invoker.services.model_manager.del_model( - model_name, base_model=base_model, model_type=model_type - ) - logger.info(f"Deleted model: {model_name}") - return Response(status_code=204) - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - - -@models_router.put( - "/convert/{base_model}/{model_type}/{model_name}", - operation_id="convert_model", - responses={ - 200: {"description": "Model converted successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "Model not found"}, - }, - status_code=200, - response_model=ConvertModelResponse, -) -async def convert_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - convert_dest_directory: Optional[str] = Query( - default=None, description="Save the converted model to the designated directory" - ), -) -> ConvertModelResponse: - """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" - logger = ApiDependencies.invoker.services.logger - try: - logger.info(f"Converting model: {model_name}") - dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None - ApiDependencies.invoker.services.model_manager.convert_model( - model_name, - base_model=base_model, - model_type=model_type, - convert_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name, base_model=base_model, model_type=model_type - ) - response = ConvertModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response - - -@models_router.get( - "/search", - operation_id="search_for_models", - responses={ - 200: {"description": "Directory searched successfully"}, - 404: {"description": "Invalid directory path"}, - }, - status_code=200, - response_model=List[pathlib.Path], -) -async def search_for_models( - search_path: pathlib.Path = Query(description="Directory path to search for models"), -) -> List[pathlib.Path]: - if not search_path.is_dir(): - raise HTTPException( - status_code=404, - detail=f"The search path '{search_path}' does not exist or is not directory", - ) - return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) - - -@models_router.get( - "/ckpt_confs", - operation_id="list_ckpt_configs", - responses={ - 200: {"description": "paths retrieved successfully"}, - }, - status_code=200, - response_model=List[pathlib.Path], -) -async def list_ckpt_configs() -> List[pathlib.Path]: - """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" - return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() - - -@models_router.post( - "/sync", - operation_id="sync_to_config", - responses={ - 201: {"description": "synchronization successful"}, - }, - status_code=201, - response_model=bool, -) -async def sync_to_config() -> bool: - """Call after making changes to models.yaml, autoimport directories or models directory to synchronize - in-memory data structures with disk data structures.""" - ApiDependencies.invoker.services.model_manager.sync_to_config() - return True - - -# There's some weird pydantic-fastapi behaviour that requires this to be a separate class -# TODO: After a few updates, see if it works inside the route operation handler? -class MergeModelsBody(BaseModel): - model_names: List[str] = Field(description="model name", min_length=2, max_length=3) - merged_model_name: Optional[str] = Field(description="Name of destination model") - alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5) - interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method") - force: Optional[bool] = Field( - description="Force merging of models created with different versions of diffusers", - default=False, - ) - - merge_dest_directory: Optional[str] = Field( - description="Save the merged model to the designated directory (with 'merged_model_name' appended)", - default=None, - ) - - model_config = ConfigDict(protected_namespaces=()) - - -@models_router.put( - "/merge/{base_model}", - operation_id="merge_models", - responses={ - 200: {"description": "Model converted successfully"}, - 400: {"description": "Incompatible models"}, - 404: {"description": "One or more models not found"}, - }, - status_code=200, - response_model=MergeModelResponse, -) -async def merge_models( - body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)], - base_model: BaseModelType = Path(description="Base model"), -) -> MergeModelResponse: - """Convert a checkpoint model into a diffusers model""" - logger = ApiDependencies.invoker.services.logger - try: - logger.info( - f"Merging models: {body.model_names} into {body.merge_dest_directory or ''}/{body.merged_model_name}" - ) - dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models( - model_names=body.model_names, - base_model=base_model, - merged_model_name=body.merged_model_name or "+".join(body.model_names), - alpha=body.alpha, - interp=body.interp, - force=body.force, - merge_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - result.name, - base_model=base_model, - model_type=ModelType.Main, - ) - response = ConvertModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException: - raise HTTPException( - status_code=404, - detail=f"One or more of the models '{body.model_names}' not found", - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6294083d0e1..e951f16d7c2 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -47,8 +47,7 @@ boards, download_queue, images, - model_records, - models, + model_manager, session_queue, sessions, utilities, @@ -115,8 +114,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(models.models_router, prefix="/api") -app.include_router(model_records.model_records_router, prefix="/api") +app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") @@ -178,21 +176,23 @@ def custom_openapi() -> dict[str, Any]: invoker_schema["class"] = "invocation" openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output" - from invokeai.backend.model_management.models import get_model_config_enums - - for model_config_format_enum in set(get_model_config_enums()): - name = model_config_format_enum.__qualname__ - - if name in openapi_schema["components"]["schemas"]: - # print(f"Config with name {name} already defined") - continue - - openapi_schema["components"]["schemas"][name] = { - "title": name, - "description": "An enumeration.", - "type": "string", - "enum": [v.value for v in model_config_format_enum], - } + # This code no longer seems to be necessary? + # Leave it here just in case + # + # from invokeai.backend.model_manager import get_model_config_formats + # formats = get_model_config_formats() + # for model_config_name, enum_set in formats.items(): + + # if model_config_name in openapi_schema["components"]["schemas"]: + # # print(f"Config with name {name} already defined") + # continue + + # openapi_schema["components"]["schemas"][model_config_name] = { + # "title": model_config_name, + # "description": "An enumeration.", + # "type": "string", + # "enum": [v.value for v in enum_set], + # } app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 49c62cff564..b5e96b80aee 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,22 +1,27 @@ from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from transformers import CLIPTokenizer +import invokeai.backend.util.logging as logger from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput +from invokeai.app.services.model_records import UnknownModelException from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.model_manager import ModelType +from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from invokeai.backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import ModelNotFoundException, ModelType -from ...backend.util.devices import torch_dtype -from ..util.ti_utils import extract_ti_triggers_from_prompt from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -66,21 +71,22 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_manager.load_model_by_key( **self.clip.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_manager.load_model_by_key( **self.clip.text_encoder.model_dump(), context=context, ) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_manager.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context ) - yield (lora_info.context.model, lora.weight) + assert isinstance(lora_info.model, LoRAModelRaw) + yield (lora_info.model, lora.weight) del lora_info return @@ -90,25 +96,20 @@ def _lora_loader(): for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model, - ) - ) - except ModelNotFoundException: + loaded_model = context.services.model_manager.load_model_by_key( + **self.clip.text_encoder.model_dump(), + context=context, + ).model + assert isinstance(loaded_model, TextualInversionModelRaw) + ti_list.append((name, loaded_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -116,7 +117,7 @@ def _lora_loader(): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -150,7 +151,7 @@ def _lora_loader(): ) conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + context.services.latents.save(conditioning_name, conditioning_data) # TODO: fix type mismatch here return ConditioningOutput( conditioning=ConditioningField( @@ -160,6 +161,8 @@ def _lora_loader(): class SDXLPromptInvocationBase: + """Prompt processor for SDXL models.""" + def run_clip_compel( self, context: InvocationContext, @@ -168,26 +171,27 @@ def run_clip_compel( get_pooled: bool, lora_prefix: str, zero_on_empty: bool, - ): - tokenizer_info = context.services.model_manager.get_model( + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: + tokenizer_info = context.services.model_manager.load_model_by_key( **clip_field.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_manager.load_model_by_key( **clip_field.text_encoder.model_dump(), context=context, ) # return zero on empty if prompt == "" and zero_on_empty: - cpu_text_encoder = text_encoder_info.context.model + cpu_text_encoder = text_encoder_info.model + assert isinstance(cpu_text_encoder, torch.nn.Module) c = torch.zeros( ( 1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size, ), - dtype=text_encoder_info.context.cache.precision, + dtype=cpu_text_encoder.dtype, ) if get_pooled: c_pooled = torch.zeros( @@ -198,12 +202,14 @@ def run_clip_compel( c_pooled = None return c, c_pooled, None - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_manager.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context ) - yield (lora_info.context.model, lora.weight) + lora_model = lora_info.model + assert isinstance(lora_model, LoRAModelRaw) + yield (lora_model, lora.weight) del lora_info return @@ -213,25 +219,24 @@ def _lora_loader(): for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model, - ) - ) - except ModelNotFoundException: + ti_model = context.services.model_manager.load_model_by_attr( + model_name=name, + base_model=text_encoder_info.config.base, + model_type=ModelType.TextualInversion, + context=context, + ).model + assert isinstance(ti_model, TextualInversionModelRaw) + ti_list.append((name, ti_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') + logger.warning(f'trigger: "{trigger}" not found') + except ValueError: + logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -239,7 +244,7 @@ def _lora_loader(): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -357,6 +362,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: dim=1, ) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -410,6 +416,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -459,9 +466,9 @@ def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: def get_max_token_count( - tokenizer, + tokenizer: CLIPTokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], - truncate_if_too_long=False, + truncate_if_too_long: bool = False, ) -> int: if type(prompt) is Blend: blend: Blend = prompt @@ -473,7 +480,9 @@ def get_max_token_count( return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)) -def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]: +def get_tokens_for_prompt_object( + tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError("Blend is not supported here - you need to get tokens for each of its .children") @@ -486,24 +495,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun for x in parsed_prompt.children ] text = " ".join(text_fragments) - tokens = tokenizer.tokenize(text) + tokens: List[str] = tokenizer.tokenize(text) if truncate_if_too_long: max_tokens_length = tokenizer.model_max_length - 2 # typically 75 tokens = tokens[0:max_tokens_length] return tokens -def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None): +def log_tokenization_for_conjunction( + c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: + assert display_label_prefix is not None this_display_label_prefix = display_label_prefix log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix) -def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): +def log_tokenization_for_prompt_object( + p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" if type(p) is Blend: blend: Blend = p @@ -543,7 +557,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text: str, + tokenizer: CLIPTokenizer, + display_label: Optional[str] = None, + truncate_if_too_long: Optional[bool] = False, +) -> None: """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 00c3fa74f6f..b9c20c79950 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -24,7 +24,7 @@ ) from controlnet_aux.util import HWC3, ade_palette from PIL import Image -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights @@ -32,7 +32,6 @@ from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.depth_anything import DepthAnythingDetector -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -57,10 +56,7 @@ class ControlNetModelField(BaseModel): """ControlNet model field""" - model_name: str = Field(description="Name of the ControlNet model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model config record key for the ControlNet model") class ControlField(BaseModel): diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 6bd28896244..6fc232b797d 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -1,8 +1,8 @@ -import os from builtins import float from typing import List, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -17,22 +17,16 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.shared.fields import FieldDescriptions -from invokeai.backend.model_management.models.base import BaseModelType, ModelType -from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +from invokeai.backend.model_manager import BaseModelType, ModelType +# LS: Consider moving these two classes into model.py class IPAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the IP-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the IP-Adapter model") class CLIPVisionModelField(BaseModel): - model_name: str = Field(description="Name of the CLIP Vision image encoder model") - base_model: BaseModelType = Field(description="Base model (usually 'Any')") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the CLIP Vision image encoder model") class IPAdapterField(BaseModel): @@ -49,12 +43,12 @@ class IPAdapterField(BaseModel): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self @@ -87,33 +81,25 @@ class IPAdapterInvocation(BaseInvocation): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. - ip_adapter_info = context.services.model_manager.model_info( - self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter - ) - # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model - # directly, and 2) we are reading from disk every time this invocation is called without caching the result. - # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this - # is currently messy due to differences between how the model info is generated when installing a model from - # disk vs. downloading the model. - image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) - ) + ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) + image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_model = CLIPVisionModelField( - model_name=image_encoder_model_name, - base_model=BaseModelType.Any, + image_encoder_models = context.services.model_manager.store.search_by_attr( + model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) + assert len(image_encoder_models) == 1 + image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b77363ceb86..7752db9d6ba 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,15 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import Any, Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np +import numpy.typing as npt import torch import torchvision.transforms as T from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers.configuration_utils import ConfigMixin from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -18,8 +20,10 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import SchedulerMixin as Scheduler +from PIL import Image from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize @@ -39,13 +43,13 @@ from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_management.models import ModelType, SilenceWarnings +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.model_manager import BaseModelType, LoadedModel +from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.util.silence_warnings import SilenceWarnings -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import BaseModelType -from ...backend.model_management.seamless import set_seamless -from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, IPAdapterData, @@ -77,7 +81,9 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +SAMPLER_NAME_VALUES = Literal[ + tuple(SCHEDULER_MAP.keys()) +] # FIXME: "Invalid type alias". This defeats static type checking. # HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to # be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale @@ -131,10 +137,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ui_order=4, ) - def prep_mask_tensor(self, mask_image): + def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: if mask_image.mode != "L": mask_image = mask_image.convert("L") - mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0) # if shape is not None: @@ -145,24 +151,24 @@ def prep_mask_tensor(self, mask_image): def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: image = context.services.images.get_pil_image(self.image.image_name) - image = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image.dim() == 3: - image = image.unsqueeze(0) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) else: - image = None + image_tensor = None mask = self.prep_mask_tensor( context.services.images.get_pil_image(self.mask.image_name), ) - if image is not None: - vae_info = context.services.model_manager.get_model( + if image_tensor is not None: + vae_info = context.services.model_manager.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) - img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) @@ -189,7 +195,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( + orig_scheduler_info = context.services.model_manager.load_model_by_key( **scheduler_info.model_dump(), context=context, ) @@ -200,7 +206,7 @@ def get_scheduler( scheduler_config = scheduler_config["_backup"] scheduler_config = { **scheduler_config, - **scheduler_extra_config, + **scheduler_extra_config, # FIXME "_backup": scheduler_config, } @@ -213,6 +219,7 @@ def get_scheduler( # hack copied over from generate.py if not hasattr(scheduler, "uses_inpainting_model"): scheduler.uses_inpainting_model = lambda: False + assert isinstance(scheduler, Scheduler) return scheduler @@ -296,7 +303,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) @field_validator("cfg_scale") - def ge_one(cls, v): + def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: """validate that all cfg_scale values are >= 1""" if isinstance(v, list): for i in v: @@ -326,9 +333,9 @@ def dispatch_progress( def get_conditioning_data( self, context: InvocationContext, - scheduler, - unet, - seed, + scheduler: Scheduler, + unet: UNet2DConditionModel, + seed: int, ) -> ConditioningData: positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) @@ -351,7 +358,7 @@ def get_conditioning_data( ), ) - conditioning_data = conditioning_data.add_scheduler_args_if_applicable( + conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME scheduler, # for ddim scheduler eta=0.0, # ddim_eta @@ -363,8 +370,8 @@ def get_conditioning_data( def create_pipeline( self, - unet, - scheduler, + unet: UNet2DConditionModel, + scheduler: Scheduler, ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( @@ -375,10 +382,10 @@ def create_pipeline( class FakeVae: class FakeVaeConfig: - def __init__(self): + def __init__(self) -> None: self.block_out_channels = [0] - def __init__(self): + def __init__(self) -> None: self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( @@ -395,11 +402,11 @@ def __init__(self): def prep_control_data( self, context: InvocationContext, - control_input: Union[ControlField, List[ControlField]], + control_input: Optional[Union[ControlField, List[ControlField]]], latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, - ) -> List[ControlNetData]: + ) -> Optional[List[ControlNetData]]: # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR @@ -422,10 +429,8 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=control_info.control_model.model_name, - model_type=ModelType.ControlNet, - base_model=control_info.control_model.base_model, + context.services.model_manager.load_model_by_key( + key=control_info.control_model.key, context=context, ) ) @@ -490,27 +495,25 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=single_ip_adapter.ip_adapter_model.model_name, - model_type=ModelType.IPAdapter, - base_model=single_ip_adapter.ip_adapter_model.base_model, + context.services.model_manager.load_model_by_key( + key=single_ip_adapter.ip_adapter_model.key, context=context, ) ) - image_encoder_model_info = context.services.model_manager.get_model( - model_name=single_ip_adapter.image_encoder_model.model_name, - model_type=ModelType.CLIPVision, - base_model=single_ip_adapter.image_encoder_model.base_model, + image_encoder_model_info = context.services.model_manager.load_model_by_key( + key=single_ip_adapter.image_encoder_model.key, context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. - single_ipa_images = single_ip_adapter.image - if not isinstance(single_ipa_images, list): - single_ipa_images = [single_ipa_images] + single_ipa_image_fields = single_ip_adapter.image + if not isinstance(single_ipa_image_fields, list): + single_ipa_image_fields = [single_ipa_image_fields] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [ + context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields + ] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -554,23 +557,19 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( - model_name=t2i_adapter_field.t2i_adapter_model.model_name, - model_type=ModelType.T2IAdapter, - base_model=t2i_adapter_field.t2i_adapter_model.base_model, + t2i_adapter_model_info = context.services.model_manager.load_model_by_key( + key=t2i_adapter_field.t2i_adapter_model.key, context=context, ) image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. - if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: + if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1: max_unet_downscale = 8 - elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL: + elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL: max_unet_downscale = 4 else: - raise ValueError( - f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'." - ) + raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.") t2i_adapter_model: T2IAdapter with t2i_adapter_model_info as t2i_adapter_model: @@ -593,7 +592,7 @@ def run_t2i_adapters( do_classifier_free_guidance=False, width=t2i_input_width, height=t2i_input_height, - num_channels=t2i_adapter_model.config.in_channels, + num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict device=t2i_adapter_model.device, dtype=t2i_adapter_model.dtype, resize_mode=t2i_adapter_field.resize_mode, @@ -618,7 +617,15 @@ def run_t2i_adapters( # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps - def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): + def init_scheduler( + self, + scheduler: Union[Scheduler, ConfigMixin], + device: torch.device, + steps: int, + denoising_start: float, + denoising_end: float, + ) -> Tuple[int, List[int], int]: + assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): scheduler.set_timesteps(steps, device="cpu") timesteps = scheduler.timesteps.to(device=device) @@ -630,11 +637,11 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en _timesteps = timesteps[:: scheduler.order] # get start timestep index - t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) + t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) # get end timestep index - t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) + t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) # apply order to indexes @@ -647,7 +654,9 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask( + self, context: InvocationContext, latents: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if self.denoise_mask is None: return None, None @@ -700,31 +709,36 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + # get the unet's config so that we can pass the base to dispatch_progress() + unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) - def _lora_loader(): + def step_callback(state: PipelineIntermediateState) -> None: + self.dispatch_progress(context, source_node_id, state, unet_config.base) + + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_manager.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context, ) - yield (lora_info.context.model, lora.weight) + yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( + unet_info = context.services.model_manager.load_model_by_key( **self.unet.unet.model_dump(), context=context, ) + assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, self.unet.seamless_axes), + ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), + set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): + assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) @@ -822,12 +836,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_manager.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) - with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: + with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -1016,8 +1031,9 @@ class ImageToLatentsInvocation(BaseInvocation): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @staticmethod - def vae_encode(vae_info, upcast, tiled, image_tensor): + def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: with vae_info as vae: + assert isinstance(vae, torch.nn.Module) orig_dtype = vae.dtype if upcast: vae.to(dtype=torch.float32) @@ -1063,7 +1079,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.services.images.get_pil_image(self.image.image_name) - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_manager.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) @@ -1082,14 +1098,19 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: @singledispatchmethod @staticmethod def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor: + assert isinstance(vae, torch.nn.Module) image_tensor_dist = vae.encode(image_tensor).latent_dist - latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible! + latents: torch.Tensor = image_tensor_dist.sample().to( + dtype=vae.dtype + ) # FIXME: uses torch.randn. make reproducible! return latents @_encode_to_tensor.register @staticmethod def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: - return vae.encode(image_tensor).latents + assert isinstance(vae, torch.nn.Module) + latents: torch.FloatTensor = vae.encode(image_tensor).latents + return latents @invocation( @@ -1122,7 +1143,12 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: # TODO: device = choose_torch_device() - def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + def slerp( + t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? + v0: Union[torch.Tensor, npt.NDArray[Any]], + v1: Union[torch.Tensor, npt.NDArray[Any]], + DOT_THRESHOLD: float = 0.9995, + ) -> Union[torch.Tensor, npt.NDArray[Any]]: """ Spherical linear interpolation Args: @@ -1155,12 +1181,16 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): v2 = s0 * v0 + s1 * v1 if inputs_are_torch: - v2 = torch.from_numpy(v2).to(device) - - return v2 + v2_torch: torch.Tensor = torch.from_numpy(v2).to(device) + return v2_torch + else: + assert isinstance(v2, np.ndarray) + return v2 # blend - blended_latents = slerp(self.alpha, latents_a, latents_b) + bl = slerp(self.alpha, latents_a, latents_b) + assert isinstance(bl, torch.Tensor) + blended_latents: torch.Tensor = bl # for type checking convenience # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") @@ -1256,15 +1286,19 @@ class IdealSizeInvocation(BaseInvocation): description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)", ) - def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR): + def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]: return tuple((x - x % multiple_of) for x in args) def invoke(self, context: InvocationContext) -> IdealSizeOutput: + unet_config = context.services.model_manager.load_model_by_key( + **self.unet.unet.model_dump(), + context=context, + ) aspect = self.width / self.height - dimension = 512 - if self.unet.unet.base_model == BaseModelType.StableDiffusion2: + dimension: float = 512 + if unet_config.base == BaseModelType.StableDiffusion2: dimension = 768 - elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL: + elif unet_config.base == BaseModelType.StableDiffusionXL: dimension = 1024 dimension = dimension * self.multiplier min_dimension = math.floor(dimension * 0.5) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 99dcc72999b..23814718997 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,12 +1,12 @@ import copy from typing import List, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.models import FreeUConfig -from ...backend.model_management import BaseModelType, ModelType, SubModelType +from ...backend.model_manager import SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -20,12 +20,8 @@ class ModelInfo(BaseModel): - model_name: str = Field(description="Info to load submodel") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -55,7 +51,7 @@ class VaeField(BaseModel): @invocation_output("unet_output") class UNetOutput(BaseInvocationOutput): - """Base class for invocations that output a UNet field""" + """Base class for invocations that output a UNet field.""" unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") @@ -84,20 +80,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): class MainModelField(BaseModel): """Main model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model key") class LoRAModelField(BaseModel): """LoRA model field""" - model_name: str = Field(description="Name of the LoRA model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="LoRA model key") @invocation( @@ -114,85 +103,40 @@ class MainModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ + if not context.services.model_manager.store.exists(key): + raise Exception(f"Unknown model {key}") return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, + key=key, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, + key=key, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, + key=key, + submodel_type=SubModelType.Vae, ), ), ) @@ -229,21 +173,16 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unkown lora name: {lora_name}!") + if not context.services.model_manager.store.exists(lora_key): + raise Exception(f"Unkown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') output = LoraLoaderOutput() @@ -251,10 +190,8 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -263,10 +200,8 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -318,24 +253,19 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unknown lora name: {lora_name}!") + if not context.services.model_manager.store.exists(lora_key): + raise Exception(f"Unknown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') - if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip2') + if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip2') output = SDXLLoraLoaderOutput() @@ -343,10 +273,8 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -355,10 +283,8 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -367,10 +293,8 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip2 = copy.deepcopy(self.clip2) output.clip2.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -381,10 +305,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: class VAEModelField(BaseModel): """Vae model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model's key") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") @@ -398,25 +319,12 @@ class VaeLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> VAEOutput: - base_model = self.vae_model.base_model - model_name = self.vae_model.model_name - model_type = ModelType.Vae - - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, - ): - raise Exception(f"Unkown vae name: {model_name}!") - return VAEOutput( - vae=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - ) - ) + key = self.vae_model.key + + if not context.services.model_manager.store.exists(key): + raise Exception(f"Unkown vae: {key}!") + + return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) @invocation_output("seamless_output") diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 759cfde700f..02f69bed745 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -8,16 +8,16 @@ import numpy as np import torch from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field, field_validator from tqdm import tqdm from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import ModelType, SubModelType +from invokeai.backend.model_patcher import ONNXModelPatcher -from ...backend.model_management import ONNXModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util import choose_torch_device from ..util.ti_utils import extract_ti_triggers_from_prompt @@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_manager.load_model_by_key( **self.clip.tokenizer.model_dump(), ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_manager.load_model_by_key( **self.clip.text_encoder.model_dump(), ) with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: loras = [ ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, + context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model, lora.weight, ) for lora in self.clip.loras @@ -84,11 +84,11 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ti_list.append( ( name, - context.services.model_manager.get_model( + context.services.model_manager.load_model_by_attr( model_name=name, - base_model=self.clip.text_encoder.base_model, + base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion, - ).context.model, + ).model, ) ) except Exception: @@ -257,13 +257,13 @@ def dispatch_progress( eta=0.0, ) - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) + unet_info = context.services.model_manager.load_model_by_key(**self.unet.unet.model_dump()) with unet_info as unet: # , ExitStack() as stack: # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [ ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, + context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model, lora.weight, ) for lora in self.unet.loras @@ -344,9 +344,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") + raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}") - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_manager.load_model_by_key( **self.vae.vae.model_dump(), ) @@ -400,11 +400,7 @@ class ONNXModelLoaderOutput(BaseInvocationOutput): class OnnxModelField(BaseModel): """Onnx model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model ID") @invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") @@ -416,93 +412,46 @@ class OnnxModelLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ + if not context.services.model_manager.store.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return ONNXModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=model_key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=model_key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, + key=model_key, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, + key=model_key, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, ), vae_decoder=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeDecoder, + key=model_key, + submodel_type=SubModelType.VaeDecoder, ), ), vae_encoder=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeEncoder, + key=model_key, + submodel_type=SubModelType.VaeEncoder, ), ), ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9d..09c9b7f3cac 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -368,7 +368,7 @@ def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): +def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput: return LatentsOutput( latents=LatentsField(latents_name=latents_name, seed=seed), width=latents.size()[3] * 8, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 68076fdfeb1..c38e5448c8e 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,6 +1,6 @@ from invokeai.app.shared.fields import FieldDescriptions +from invokeai.backend.model_manager import SubModelType -from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -44,72 +44,52 @@ class SDXLModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_manager.store.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=model_key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=model_key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, + key=model_key, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, + key=model_key, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer2, + key=model_key, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder2, + key=model_key, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, + key=model_key, + submodel_type=SubModelType.Vae, ), ), ) @@ -133,56 +113,40 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_manager.store.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=model_key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=model_key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer2, + key=model_key, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder2, + key=model_key, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, + key=model_key, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e055d23903f..09819672b71 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -16,14 +16,10 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.shared.fields import FieldDescriptions -from invokeai.backend.model_management.models.base import BaseModelType class T2IAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the T2I-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model record key for the T2I-Adapter model") class T2IAdapterField(BaseModel): diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index a304b38a955..c73aa438096 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" initconf: ClassVar[Optional[DictConfig]] = None - argparse_groups: ClassVar[Dict] = {} + argparse_groups: ClassVar[Dict[str, Any]] = {} model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) - def parse_args(self, argv: Optional[list] = sys.argv[1:]): + def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None: """Call to parse command-line arguments.""" parser = self.get_parser() opt, unknown_opts = parser.parse_known_args(argv) @@ -68,7 +68,7 @@ def to_yaml(self) -> str: return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser): + def add_parser_arguments(cls, parser: ArgumentParser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] @@ -117,7 +117,8 @@ def cmd_name(cls, command_field: str = "type") -> str: """Return the category of a setting.""" hints = get_type_hints(cls) if command_field in hints: - return get_args(hints[command_field])[0] + result: str = get_args(hints[command_field])[0] + return result else: return "Uncategorized" @@ -158,7 +159,7 @@ def _excluded_from_yaml(cls) -> List[str]: ] @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override=None): + def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None: """Add the argparse arguments for a setting parser.""" field_type = get_type_hints(cls).get(name) default = ( diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index d11bcabcf9c..27a0f859c23 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser): It also supports reading defaults from an init file. """ - def print_help(self, file=None): + def print_help(self, file=None) -> None: text = self.format_help() pydoc.pager(text) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 132afc22722..2af775372dd 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -173,7 +173,7 @@ class InvokeBatch(InvokeAISettings): import os from pathlib import Path -from typing import Any, ClassVar, Dict, List, Literal, Optional, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional from omegaconf import DictConfig, OmegaConf from pydantic import Field @@ -185,7 +185,9 @@ class InvokeBatch(InvokeAISettings): INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") -DEFAULT_MAX_VRAM = 0.5 +DEFAULT_RAM_CACHE = 10.0 +DEFAULT_VRAM_CACHE = 0.25 +DEFAULT_CONVERT_CACHE = 20.0 class Categories(object): @@ -237,6 +239,7 @@ class InvokeAIAppConfig(InvokeAISettings): autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) + convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths) db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths) outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths) @@ -260,8 +263,10 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other) # CACHE - ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) - vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) + lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache) @@ -404,6 +409,11 @@ def models_path(self) -> Path: """Path to the models directory.""" return self._resolve(self.models_dir) + @property + def models_convert_cache_path(self) -> Path: + """Path to the converted cache models directory.""" + return self._resolve(self.convert_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory.""" @@ -433,15 +443,20 @@ def invisible_watermark(self) -> bool: return True @property - def ram_cache_size(self) -> Union[Literal["auto"], float]: - """Return the ram cache size using the legacy or modern setting.""" + def ram_cache_size(self) -> float: + """Return the ram cache size using the legacy or modern setting (GB).""" return self.max_cache_size or self.ram @property - def vram_cache_size(self) -> Union[Literal["auto"], float]: - """Return the vram cache size using the legacy or modern setting.""" + def vram_cache_size(self) -> float: + """Return the vram cache size using the legacy or modern setting (GB).""" return self.max_vram_cache_size or self.vram + @property + def convert_cache_size(self) -> float: + """Return the convert cache size on disk (GB).""" + return self.convert_cache + @property def use_cpu(self) -> bool: """Return true if the device is set to CPU or the always_use_cpu flag is set.""" diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index f854f64f585..2ac13b825fe 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -260,3 +260,16 @@ def cancel_job(self, job: DownloadJob) -> None: def join(self) -> None: """Wait until all jobs are off the queue.""" pass + + @abstractmethod + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Wait until the indicated download job has reached a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + pass diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7613c0893fc..50cac80d094 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -4,10 +4,11 @@ import os import re import threading +import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl @@ -48,11 +49,12 @@ def __init__( :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ - self._jobs = {} + self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 - self._queue = PriorityQueue() + self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() - self._worker_pool = set() + self._job_completed_event = threading.Event() + self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") self._event_bus = event_bus @@ -188,6 +190,16 @@ def cancel_all_jobs(self) -> None: if not job.in_terminal_state: self.cancel_job(job) + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._job_completed_event.wait(timeout=0.25): # in case we miss an event + self._job_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + def _start_workers(self, max_workers: int) -> None: """Start the requested number of worker threads.""" self._stop_event.clear() @@ -223,6 +235,7 @@ def _download_next_item(self) -> None: finally: job.job_ended = get_iso_timestamp() + self._job_completed_event.set() # signal a change to terminal state self._queue.task_done() self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") @@ -407,11 +420,11 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: # Example on_progress event handler to display a TQDM status bar # Activate with: -# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update +# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update)) class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" - _bars: Dict[int, tqdm] # the tqdm object + _bars: Dict[int, tqdm] # type: ignore _last: Dict[int, int] # last bytes downloaded def __init__(self) -> None: # noqa D107 diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f33495..af6fe4923f7 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,8 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import ModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModelConfig class EventServiceBase: @@ -171,10 +170,7 @@ def emit_model_load_started( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is requested""" self.__emit_queue_event( @@ -184,10 +180,7 @@ def emit_model_load_started( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, + "model_config": model_config.model_dump(), }, ) @@ -197,11 +190,7 @@ def emit_model_load_completed( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -211,13 +200,7 @@ def emit_model_load_completed( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, - "hash": model_info.hash, - "location": str(model_info.location), - "precision": str(model_info.precision), + "model_config": model_config.model_dump(), }, ) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 11a4de99d6e..aa3322a9a0a 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -22,9 +22,7 @@ from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC from .latents_storage.latents_storage_base import LatentsStorageBase - from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase - from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase @@ -50,9 +48,7 @@ class InvocationServices: latents: "LatentsStorageBase" logger: "Logger" model_manager: "ModelManagerServiceBase" - model_records: "ModelRecordServiceBase" download_queue: "DownloadQueueServiceBase" - model_install: "ModelInstallServiceBase" processor: "InvocationProcessorABC" performance_statistics: "InvocationStatsServiceBase" queue: "InvocationQueueABC" @@ -78,9 +74,7 @@ def __init__( latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", - model_records: "ModelRecordServiceBase", download_queue: "DownloadQueueServiceBase", - model_install: "ModelInstallServiceBase", processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", @@ -104,9 +98,7 @@ def __init__( self.latents = latents self.logger = logger self.model_manager = model_manager - self.model_records = model_records self.download_queue = download_queue - self.model_install = model_install self.processor = processor self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py index 22624a6579a..ec8a453323d 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_base.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py @@ -29,8 +29,8 @@ """ from abc import ABC, abstractmethod -from contextlib import AbstractContextManager from pathlib import Path +from typing import Iterator from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary @@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" @abstractmethod - def __init__(self): + def __init__(self) -> None: """ Initialize the InvocationStatsService and reset counters to zero """ - pass @abstractmethod def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, - ) -> AbstractContextManager: + ) -> Iterator[None]: """ Return a context object that will capture the statistics on the execution of invocaation. Use with: to place around the part of the code that executes the invocation. @@ -61,7 +60,7 @@ def collect_stats( pass @abstractmethod - def reset_stats(self, graph_execution_state_id: str): + def reset_stats(self, graph_execution_state_id: str) -> None: """ Reset all statistics for the indicated graph. :param graph_execution_state_id: The id of the session whose stats to reset. @@ -70,7 +69,7 @@ def reset_stats(self, graph_execution_state_id: str): pass @abstractmethod - def log_stats(self, graph_execution_state_id: str): + def log_stats(self, graph_execution_state_id: str) -> None: """ Write out the accumulated statistics to the log or somewhere else. :param graph_execution_state_id: The id of the session whose stats to log. diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index be58aaad2dd..486a1ca5b3e 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -2,6 +2,7 @@ import time from contextlib import contextmanager from pathlib import Path +from typing import Iterator import psutil import torch @@ -10,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_common import ( @@ -41,7 +42,10 @@ def start(self, invoker: Invoker) -> None: self._invoker = invoker @contextmanager - def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str): + def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + # This is to handle case of the model manager not being initialized, which happens + # during some tests. + services = self._invoker.services if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() @@ -55,8 +59,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st start_ram = psutil.Process().memory_info().rss if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - if self._invoker.services.model_manager: - self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id]) + + assert services.model_manager.load is not None + services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. @@ -73,7 +78,7 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def _prune_stale_stats(self): + def _prune_stale_stats(self) -> None: """Check all graphs being tracked and prune any that have completed/errored. This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/latents_storage/latents_storage_base.py index 9fa42b0ae61..25972591265 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/latents_storage/latents_storage_base.py @@ -1,10 +1,12 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Union import torch +from invokeai.app.invocations.compel import ConditioningFieldData + class LatentsStorageBase(ABC): """Responsible for storing and retrieving latents.""" @@ -20,8 +22,10 @@ def __init__(self) -> None: def get(self, name: str) -> torch.Tensor: pass + # (LS) Added a Union with ConditioningFieldData to fix type mismatch errors in compel.py + # Not 100% sure this isn't an existing bug. @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None: pass @abstractmethod diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py index 9192b9147f7..cc94a25e5ae 100644 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ b/invokeai/app/services/latents_storage/latents_storage_disk.py @@ -5,6 +5,7 @@ import torch +from invokeai.app.invocations.compel import ConditioningFieldData from invokeai.app.services.invoker import Invoker from .latents_storage_base import LatentsStorageBase @@ -27,7 +28,7 @@ def get(self, name: str) -> torch.Tensor: latent_path = self.get_path(name) return torch.load(latent_path) - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None: self.__output_folder.mkdir(parents=True, exist_ok=True) latent_path = self.get_path(name) torch.save(data, latent_path) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py index 6232b76a27d..3a0322011da 100644 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py @@ -1,10 +1,11 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from queue import Queue -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch +from invokeai.app.invocations.compel import ConditioningFieldData from invokeai.app.services.invoker import Invoker from .latents_storage_base import LatentsStorageBase @@ -46,7 +47,9 @@ def get(self, name: str) -> torch.Tensor: self.__set_cache(name, latent) return latent - def save(self, name: str, data: torch.Tensor) -> None: + # TODO: (LS) ConditioningFieldData added as Union because of type-checking errors + # in compel.py. Unclear whether this is a long-standing bug, but seems to run. + def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None: self.__underlying_storage.save(name, data) self.__set_cache(name, data) self._on_changed(data) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 635cb154d64..080219af75e 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -14,11 +14,13 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase -from invokeai.app.services.events import EventServiceBase +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class InstallStatus(str, Enum): @@ -127,8 +129,8 @@ def proper_repo_id(cls, v: str) -> str: # noqa D102 def __str__(self) -> str: """Return string version of repoid when string rep needed.""" base: str = self.repo_id + base += f":{self.variant or ''}" base += f":{self.subfolder}" if self.subfolder else "" - base += f" ({self.variant})" if self.variant else "" return base @@ -243,7 +245,7 @@ def __init__( app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: ModelMetadataStore, + metadata_store: ModelMetadataStoreBase, event_bus: Optional["EventServiceBase"] = None, ): """ @@ -324,6 +326,43 @@ def install_path( :returns id: The string ID of the registered model. """ + @abstractmethod + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + r"""Install the indicated model using heuristics to interpret user intentions. + + :param source: String source + :param config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record as described in `import_model()`. + :param access_token: Optional access token for remote sources. + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. + + The previous support for recursing into a local folder and loading all model-like files + has been removed. + """ + pass + @abstractmethod def import_model( self, @@ -385,6 +424,18 @@ def prune_jobs(self) -> None: def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" + @abstractmethod + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Wait for the indicated job to reach a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + @abstractmethod def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: """ @@ -394,7 +445,8 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: completed, been cancelled, or errored out. :param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if - installs do not complete within the indicated time. + installs do not complete within the indicated time. A timeout of zero (the default) + will block indefinitely until the installs complete. """ @abstractmethod @@ -410,3 +462,22 @@ def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: @abstractmethod def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" + + @abstractmethod + def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + """ + Download the model file located at source to the models cache and return its Path. + + :param source: A Url or a string that can be converted into one. + :param access_token: Optional access token to access restricted resources. + + The model file will be downloaded into the system-wide model cache + (`models/.cache`) if it isn't already there. Note that the model cache + is periodically cleared of infrequently-used entries when the model + converter runs. + + Note that this doesn't automaticallly install or register the model, but is + intended for use by nodes that need access to models that aren't directly + supported by InvokeAI. The downloading process takes advantage of the download queue + to avoid interrupting other operations. + """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 82c667f584f..7dee8bfd8cb 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,10 +17,10 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker -from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -33,7 +33,6 @@ AnyModelRepoMetadata, CivitaiMetadataFetch, HuggingFaceMetadataFetch, - ModelMetadataStore, ModelMetadataWithFiles, RemoteModelFile, ) @@ -50,6 +49,7 @@ ModelInstallJob, ModelInstallServiceBase, ModelSource, + StringLikeSource, URLModelSource, ) @@ -64,7 +64,6 @@ def __init__( app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: Optional[ModelMetadataStore] = None, event_bus: Optional[EventServiceBase] = None, session: Optional[Session] = None, ): @@ -86,19 +85,13 @@ def __init__( self._lock = threading.Lock() self._stop_event = threading.Event() self._downloads_changed_event = threading.Event() + self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._running = False self._session = session self._next_job_id = 0 - # There may not necessarily be a metadata store initialized - # so we create one and initialize it with the same sql database - # used by the record store service. - if metadata_store: - self._metadata_store = metadata_store - else: - assert isinstance(record_store, ModelRecordServiceSQL) - self._metadata_store = ModelMetadataStore(record_store.db) + self._metadata_store = record_store.metadata_store # for convenience @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 @@ -145,7 +138,7 @@ def register_path( ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get("source") is None: + if not config.get("source"): config["source"] = model_path.resolve().as_posix() return self._register(model_path, config) @@ -156,12 +149,14 @@ def install_path( ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get("source") is None: + if not config.get("source"): config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) - old_hash = info.original_hash - dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name + old_hash = info.current_hash + dest_path = ( + self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name) + ) try: new_path = self._copy_model(model_path, dest_path) except FileExistsError as excp: @@ -177,7 +172,40 @@ def install_path( info, ) + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=match.group(2) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + access_token=access_token, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=AnyHttpUrl(source), + access_token=access_token, + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return self.import_model(source_obj, config) + def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 + similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] + if similar_jobs: + self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.") + return similar_jobs[0] + if isinstance(source, LocalModelSource): install_job = self._import_local_model(source, config) self._install_queue.put(install_job) # synchronously install @@ -207,14 +235,25 @@ def get_job_by_id(self, id: int) -> ModelInstallJob: # noqa D102 assert isinstance(jobs[0], ModelInstallJob) return jobs[0] + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._install_completed_event.wait(timeout=5): # in case we miss an event + self._install_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + + # TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" start = time.time() while len(self._download_cache) > 0: - if self._downloads_changed_event.wait(timeout=5): # in case we miss an event + if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: - raise Exception("Timeout exceeded") + raise TimeoutError("Timeout exceeded") self._install_queue.join() return self._install_jobs @@ -268,6 +307,38 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 path.unlink() self.unregister(key) + def download_and_cache( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: int = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path.""" + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + model_path = self._app_config.models_convert_cache_path / model_hash + + # We expect the cache directory to contain one and only one downloaded file. + # We don't know the file's name in advance, as it is set by the download + # content-disposition header. + if model_path.exists(): + contents = [x for x in model_path.iterdir() if x.is_file()] + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + job = self._download_queue.download( + source=AnyHttpUrl(str(source)), + dest=model_path, + access_token=access_token, + on_progress=TqdmProgress().update, + ) + self._download_queue.wait_for_job(job, timeout) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -300,6 +371,7 @@ def _install_next_item(self) -> None: job.total_bytes = self._stat_size(job.local_path) job.bytes = job.total_bytes self._signal_job_running(job) + job.config_in["source"] = str(job.source) if job.inplace: key = self.register_path(job.local_path, job.config_in) else: @@ -330,6 +402,7 @@ def _install_next_item(self) -> None: # if this is an install of a remote file, then clean up the temporary directory if job._install_tmpdir is not None: rmtree(job._install_tmpdir) + self._install_completed_event.set() self._install_queue.task_done() self._logger.info("Install thread exiting") @@ -489,10 +562,10 @@ def _next_id(self) -> int: return id @staticmethod - def _guess_variant() -> ModelRepoVariant: + def _guess_variant() -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT + return ModelRepoVariant.FP16 if precision == "float16" else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( @@ -517,7 +590,7 @@ def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any] if not source.access_token: self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id) + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) assert isinstance(metadata, ModelMetadataWithFiles) remote_files = metadata.download_urls( variant=source.variant or self._guess_variant(), @@ -565,6 +638,8 @@ def _import_remote_model( # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # Currently the tmpdir isn't automatically removed at exit because it is # being held in a daemon thread. + if len(remote_files) == 0: + raise ValueError(f"{source}: No downloadable files found") tmpdir = Path( mkdtemp( dir=self._app_config.models_path, @@ -580,6 +655,16 @@ def _import_remote_model( bytes=0, total_bytes=0, ) + # In the event that there is a subfolder specified in the source, + # we need to remove it from the destination path in order to avoid + # creating unwanted subfolders + if hasattr(source, "subfolder") and source.subfolder: + root = Path(remote_files[0].path.parts[0]) + subfolder = root / source.subfolder + else: + root = Path(".") + subfolder = Path(".") + # we remember the path up to the top of the tmpdir so that it may be # removed safely at the end of the install process. install_job._install_tmpdir = tmpdir @@ -589,7 +674,7 @@ def _import_remote_model( self._logger.debug(f"remote_files={remote_files}") for model_file in remote_files: url = model_file.url - path = model_file.path + path = root / model_file.path.relative_to(subfolder) self._logger.info(f"Downloading {url} => {path}") install_job.total_bytes += model_file.size assert hasattr(source, "access_token") diff --git a/invokeai/app/services/model_load/__init__.py b/invokeai/app/services/model_load/__init__.py new file mode 100644 index 00000000000..b4a86e9348d --- /dev/null +++ b/invokeai/app/services/model_load/__init__.py @@ -0,0 +1,6 @@ +"""Initialization file for model load service module.""" + +from .model_load_base import ModelLoadServiceBase +from .model_load_default import ModelLoadService + +__all__ = ["ModelLoadServiceBase", "ModelLoadService"] diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py new file mode 100644 index 00000000000..cd584395f56 --- /dev/null +++ b/invokeai/app/services/model_load/model_load_base.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team +"""Base class for model loader.""" + +from abc import ABC, abstractmethod +from typing import Optional + +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import LoadedModel +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase + + +class ModelLoadServiceBase(ABC): + """Wrapper around AnyModelLoader.""" + + @abstractmethod + def load_model( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's configuration, load it and return the LoadedModel object. + + :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) + :param submodel: For main (pipeline models), the submodel to fetch. + :param context: Invocation context used for event reporting + """ + + @property + @abstractmethod + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + + @property + @abstractmethod + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py new file mode 100644 index 00000000000..e020ba9d1cd --- /dev/null +++ b/invokeai/app/services/model_load/model_load_default.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team +"""Implementation of model loader service.""" + +from typing import Optional, Type + +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import LoadedModel, ModelLoaderRegistry, ModelLoaderRegistryBase +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.util.logging import InvokeAILogger + +from .model_load_base import ModelLoadServiceBase + + +class ModelLoadService(ModelLoadServiceBase): + """Wrapper around ModelLoaderRegistry.""" + + def __init__( + self, + app_config: InvokeAIAppConfig, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, + ): + """Initialize the model load service.""" + logger = InvokeAILogger.get_logger(self.__class__.__name__) + logger.setLevel(app_config.log_level.upper()) + self._logger = logger + self._app_config = app_config + self._ram_cache = ram_cache + self._convert_cache = convert_cache + self._registry = registry + + @property + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + return self._ram_cache + + @property + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" + return self._convert_cache + + def load_model( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's configuration, load it and return the LoadedModel object. + + :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) + :param submodel: For main (pipeline models), the submodel to fetch. + :param context: Invocation context used for event reporting + """ + if context: + self._emit_load_event( + context=context, + model_config=model_config, + ) + + implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore + loaded_model: LoadedModel = implementation( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self._convert_cache, + ).load_model(model_config, submodel_type) + + if context: + self._emit_load_event( + context=context, + model_config=model_config, + loaded=True, + ) + return loaded_model + + def _emit_load_event( + self, + context: InvocationContext, + model_config: AnyModelConfig, + loaded: Optional[bool] = False, + ) -> None: + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() + + if not loaded: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) + else: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 3d6a9c248c6..5455577266a 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -1 +1,17 @@ -from .model_manager_default import ModelManagerService # noqa F401 +"""Initialization file for model manager service.""" + +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load import LoadedModel + +from .model_manager_default import ModelManagerService, ModelManagerServiceBase + +__all__ = [ + "ModelManagerServiceBase", + "ModelManagerService", + "AnyModel", + "AnyModelConfig", + "BaseModelType", + "ModelType", + "SubModelType", + "LoadedModel", +] diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 4c2fc4c085c..1116c82ff1f 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,286 +1,67 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team -from __future__ import annotations - from abc import ABC, abstractmethod -from logging import Logger -from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union -from pydantic import Field +from typing_extensions import Self -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - MergeInterpolationMethod, - ModelInfo, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.app.services.invoker import Invoker -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallServiceBase +from ..model_load import ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase +from ..shared.sqlite.sqlite_database import SqliteDatabase class ModelManagerServiceBase(ABC): - """Responsible for managing models on disk and in memory""" - - @abstractmethod - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - pass - - @abstractmethod - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, - ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) - of a diffusers pipeline.""" - pass - - @property - @abstractmethod - def logger(self): - pass - - @abstractmethod - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - pass - - @abstractmethod - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - Uses the exact format as the omegaconf stanza. - """ - pass - - @abstractmethod - def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict: - """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } - """ - pass - - @abstractmethod - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Return information about the model using the same format as list_models() - """ - pass - - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - pass - - @abstractmethod - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException if the name does not already exist. - - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. - """ - pass - - @abstractmethod - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): - """ - Rename the indicated model. - """ - pass + """Abstract base class for the model manager service.""" - @abstractmethod - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - pass + # attributes: + # store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") + # install: ModelInstallServiceBase = Field(description="An instance of the model install service.") + # load: ModelLoadServiceBase = Field(description="An instance of the model load service.") + @classmethod @abstractmethod - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ - pass - - @abstractmethod - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. + Construct the model manager service instance. - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. + Use it rather than the __init__ constructor. This class + method simplifies the construction considerably. """ pass + @property @abstractmethod - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ + def store(self) -> ModelRecordServiceBase: + """Return the ModelRecordServiceBase used to store and retrieve configuration records.""" pass + @property @abstractmethod - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ + def load(self) -> ModelLoadServiceBase: + """Return the ModelLoadServiceBase used to load models from their configuration records.""" pass + @property @abstractmethod - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ + def install(self) -> ModelInstallServiceBase: + """Return the ModelInstallServiceBase used to download and manipulate model files.""" pass @abstractmethod - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ + def start(self, invoker: Invoker) -> None: pass @abstractmethod - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ + def stop(self, invoker: Invoker) -> None: pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cdb3e59a91c..0f4d1a7ecf4 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,413 +1,149 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team +"""Implementation of ModelManagerServiceBase.""" -from __future__ import annotations +from typing import Optional -from logging import Logger -from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union +from typing_extensions import Self -import torch -from pydantic import Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - MergeInterpolationMethod, - ModelInfo, - ModelManager, - ModelMerger, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats -from invokeai.backend.model_management.model_search import FindModels -from invokeai.backend.util import choose_precision, choose_torch_device +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.app.services.invoker import Invoker +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry +from invokeai.backend.util.logging import InvokeAILogger +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallService, ModelInstallServiceBase +from ..model_load import ModelLoadService, ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase, UnknownModelException from .model_manager_base import ModelManagerServiceBase -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import InvocationContext - -# simple implementation class ModelManagerService(ModelManagerServiceBase): - """Responsible for managing models on disk and in memory""" + """ + The ModelManagerService handles various aspects of model installation, maintenance and loading. + + It bundles three distinct services: + model_manager.store -- Routines to manage the database of model configuration records. + model_manager.install -- Routines to install, move and delete models. + model_manager.load -- Routines to load models into memory. + """ def __init__( self, - config: InvokeAIAppConfig, - logger: Logger, + store: ModelRecordServiceBase, + install: ModelInstallServiceBase, + load: ModelLoadServiceBase, ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - if config.model_conf_path and config.model_conf_path.exists(): - config_file = config.model_conf_path - else: - config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f"Config file={config_file}") + self._store = store + self._install = install + self._load = load - device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" - logger.info(f"GPU device = {device} {device_name}") - - precision = config.precision - if precision == "auto": - precision = choose_precision(device) - dtype = torch.float32 if precision == "float32" else torch.float16 + @property + def store(self) -> ModelRecordServiceBase: + return self._store - # this is transitional backward compatibility - # support for the deprecated `max_loaded_models` - # configuration value. If present, then the - # cache size is set to 2.5 GB times - # the number of max_loaded_models. Otherwise - # use new `ram_cache_size` config setting - max_cache_size = config.ram_cache_size + @property + def install(self) -> ModelInstallServiceBase: + return self._install - logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + @property + def load(self) -> ModelLoadServiceBase: + return self._load - sequential_offload = config.sequential_guidance + def start(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "start"): + service.start(invoker) - self.mgr = ModelManager( - config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger, - ) - logger.info("Model manager service initialized") + def stop(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "stop"): + service.stop(invoker) - def get_model( + def load_model_by_config( self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, context: Optional[InvocationContext] = None, - ) -> ModelInfo: - """ - Retrieve the indicated model. submodel can be used to get a - part (such as the vae) of a diffusers mode. - """ - - # we can emit model loading events if we are executing with access to the invocation context - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - model_info = self.mgr.get_model( - model_name, - base_model, - model_type, - submodel, - ) - - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - model_info=model_info, - ) - - return model_info - - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - """ - Given a model name, returns True if it is a valid - identifier. - """ - return self.mgr.model_exists( - model_name, - base_model, - model_type, - ) - - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - """ - return self.mgr.model_info(model_name, base_model, model_type) - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - return self.mgr.model_names() - - def list_models( - self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> list[dict]: - """ - Return a list of models. - """ - return self.mgr.list_models(base_model, model_type) - - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Return information about the model using the same format as list_models() - """ - return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) + ) -> LoadedModel: + return self.load.load_model(model_config, submodel_type, context) - def add_model( + def load_model_by_key( self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"add/update model {model_name}") - return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException exception if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"update model {model_name}") - if not self.model_exists(model_name, base_model, model_type): - raise ModelNotFoundException(f"Unknown model {model_name}") - return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + config = self.store.get_model(key) + return self.load.load_model(config, submodel_type, context) - def del_model( + def load_model_by_attr( self, model_name: str, base_model: BaseModelType, model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. - """ - self.logger.debug(f"delete model {model_name}") - self.mgr.del_model(model_name, base_model, model_type) - self.mgr.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - convert_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) + Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - self.logger.debug(f"convert model {model_name}") - return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) + This is provided for API compatability with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that ws returned. - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ - self.mgr.cache.stats = cache_stats + :param model_name: Name of to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. - def commit(self, conf_file: Optional[Path] = None): - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination """ - return self.mgr.commit(conf_file) - - def _emit_load_event( - self, - context: InvocationContext, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - model_info: Optional[ModelInfo] = None, - ): - if context.services.queue.is_canceled(context.graph_execution_state_id): - raise CanceledException() - - if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - model_info=model_info, - ) + configs = self.store.search_by_attr(model_name, base_model, model_type) + if len(configs) == 0: + raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") + elif len(configs) > 1: + raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) + return self.load.load_model(configs[0], submodel, context) - @property - def logger(self): - return self.mgr.logger - - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. + @classmethod + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + model_record_service: ModelRecordServiceBase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + Construct the model manager service instance. - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) + For simplicity, use this class method rather than the __init__ constructor. """ - merger = ModelMerger(self.mgr) - try: - result = merger.merge_diffusion_models_and_save( - model_names=model_names, - base_model=base_model, - merged_model_name=merged_model_name, - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=merge_dest_directory, - ) - except AssertionError as e: - raise ValueError(e) - return result + logger = InvokeAILogger.get_logger(cls.__name__) + logger.setLevel(app_config.log_level.upper()) - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - search = FindModels([directory], self.logger) - return search.list_models() - - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - return self.mgr.sync_to_config() - - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - config = self.mgr.app_config - conf_path = config.legacy_conf_path - root_path = config.root_path - return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ): - """ - Rename the indicated model. Can provide a new name and/or a new base. - :param model_name: Current name of the model - :param base_model: Current base of the model - :param model_type: Model type (can't be changed) - :param new_name: New name for the model - :param new_base: New base for the model - """ - self.mgr.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=new_name, - new_base=new_base, + ram_cache = ModelCache( + max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger + ) + convert_cache = ModelConvertCache( + cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size + ) + loader = ModelLoadService( + app_config=app_config, + ram_cache=ram_cache, + convert_cache=convert_cache, + registry=ModelLoaderRegistry, + ) + installer = ModelInstallService( + app_config=app_config, + record_store=model_record_service, + download_queue=download_queue, + event_bus=events, ) + return cls(store=model_record_service, install=installer, load=loader) diff --git a/invokeai/app/services/model_metadata/__init__.py b/invokeai/app/services/model_metadata/__init__.py new file mode 100644 index 00000000000..981c96b709b --- /dev/null +++ b/invokeai/app/services/model_metadata/__init__.py @@ -0,0 +1,9 @@ +"""Init file for ModelMetadataStoreService module.""" + +from .metadata_store_base import ModelMetadataStoreBase +from .metadata_store_sql import ModelMetadataStoreSQL + +__all__ = [ + "ModelMetadataStoreBase", + "ModelMetadataStoreSQL", +] diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py b/invokeai/app/services/model_metadata/metadata_store_base.py new file mode 100644 index 00000000000..e0e4381b099 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_base.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Storage for Model Metadata +""" + +from abc import ABC, abstractmethod +from typing import List, Set, Tuple + +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class ModelMetadataStoreBase(ABC): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + @abstractmethod + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + + @abstractmethod + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + + @abstractmethod + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + + @abstractmethod + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + + @abstractmethod + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + + @abstractmethod + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + + @abstractmethod + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + + @abstractmethod + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py new file mode 100644 index 00000000000..afe9d2c8c69 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_sql.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +SQL Storage for Model Metadata +""" + +import sqlite3 +from typing import List, Optional, Set, Tuple + +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase + +from .metadata_store_base import ModelMetadataStoreBase + + +class ModelMetadataStoreSQL(ModelMetadataStoreBase): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + def __init__(self, db: SqliteDatabase): + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. + + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + super().__init__() + self._db = db + self._cursor = self._db.conn.cursor() + + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + json_serialized = metadata.model_dump_json() + with self._db.lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_metadata( + id, + metadata + ) + VALUES (?,?); + """, + ( + model_key, + json_serialized, + ), + ) + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table + self._db.conn.rollback() + raise UnknownMetadataException from excp + except sqlite3.Error as excp: + self._db.conn.rollback() + raise excp + + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT metadata FROM model_metadata + WHERE id=?; + """, + (model_key,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownMetadataException("model metadata not found") + return ModelMetadataFetchBase.from_json(rows[0]) + + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT id,metadata FROM model_metadata; + """, + (), + ) + rows = self._cursor.fetchall() + return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] + + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + json_serialized = metadata.model_dump_json() # turn it into a json string. + with self._db.lock: + try: + self._cursor.execute( + """--sql + UPDATE model_metadata + SET + metadata=? + WHERE id=?; + """, + (json_serialized, model_key), + ) + if self._cursor.rowcount == 0: + raise UnknownMetadataException("model metadata not found") + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.Error as e: + self._db.conn.rollback() + raise e + + return self.get_metadata(model_key) + + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + self._cursor.execute( + """--sql + select tag_text from tags; + """ + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + with self._db.lock: + try: + matches: Optional[Set[str]] = None + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.model_id FROM model_tags AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + if matches is None: + matches = model_keys + matches = matches.intersection(model_keys) + except sqlite3.Error as e: + raise e + return matches if matches else set() + + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE author=?; + """, + (author,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE name=?; + """, + (name,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def _update_tags(self, model_key: str, tags: Set[str]) -> None: + """Update tags for the model referenced by model_key.""" + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE model_id=?; + """, + (model_key,), + ) + + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? + LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + model_id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 57597570cde..d6014db448a 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,8 +11,15 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class DuplicateModelException(Exception): @@ -104,7 +111,7 @@ def get_model(self, key: str) -> AnyModelConfig: @property @abstractmethod - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" pass @@ -146,7 +153,7 @@ def list_models( @abstractmethod def exists(self, key: str) -> bool: """ - Return True if a model with the indicated key exists in the databse. + Return True if a model with the indicated key exists in the database. :param key: Unique key for the model to be deleted """ diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 4512da5d413..dcd1114655b 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,8 +54,9 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, @@ -69,16 +70,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. - :param conn: sqlite3 connection object - :param lock: threading Lock object + :param db: Sqlite connection object """ super().__init__() self._db = db - self._cursor = self._db.conn.cursor() + self._cursor = db.conn.cursor() + self._metadata_store = metadata_store @property def db(self) -> SqliteDatabase: @@ -158,7 +159,7 @@ def del_model(self, key: str) -> None: self._db.conn.rollback() raise e - def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: + def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Update the model, returning the updated version. @@ -199,7 +200,7 @@ def get_model(self, key: str) -> AnyModelConfig: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE id=?; """, (key,), @@ -207,7 +208,7 @@ def get_model(self, key: str) -> AnyModelConfig: rows = self._cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0])) + model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model def exists(self, key: str) -> bool: @@ -265,12 +266,14 @@ def search_by_attr( with self._db.lock: self._cursor.execute( f"""--sql - select config FROM model_config + select config, strftime('%s',updated_at) FROM model_config {where}; """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -279,12 +282,14 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -293,18 +298,20 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE original_hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results @property - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" - return ModelMetadataStore(self._db) + return self._metadata_store def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: """ @@ -325,18 +332,18 @@ def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: :param tags: Set of tags to search for. All tags must be present. """ - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) keys = store.search_by_tag(tags) return [self.get_model(x) for x in keys] def list_tags(self) -> Set[str]: """Return a unique set of all the model tags in the metadata database.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_tags() def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: """List metadata for all models that have it.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_all_metadata() def list_models( diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 6079b3f08d7..681886eacd3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_3(app_config=config, logger=logger)) migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_5()) + migrator.register_migration(build_migration_6()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py new file mode 100644 index 00000000000..1f9ac56518c --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -0,0 +1,62 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration6Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._recreate_model_triggers(cursor) + self._delete_ip_adapters(cursor) + + def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: + """ + Adds the timestamp trigger to the model_config table. + + This trigger was inadvertently dropped in earlier migration scripts. + """ + + cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_updated_at + AFTER UPDATE + ON model_config FOR EACH ROW + BEGIN + UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + + def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None: + """ + Delete all the IP adapters. + + The model manager will automatically find and re-add them after the migration + is done. This allows the manager to add the correct image encoder to their + configuration records. + """ + + cursor.execute( + """--sql + DELETE FROM model_config + WHERE type='ip_adapter'; + """ + ) + + +def build_migration_6() -> Migration: + """ + Build the migration from database version 5 to 6. + + This migration does the following: + - Adds the model_config_updated_at trigger if it does not exist + - Delete all ip_adapter models so that the model prober can find and + update with the correct image processor model. + """ + migration_6 = Migration( + from_version=5, + to_version=6, + callback=Migration6Callback(), + ) + + return migration_6 diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 910b05d8dde..da431929dbe 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -5,7 +5,7 @@ import numpy as np -def get_timestamp(): +def get_timestamp() -> int: return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) @@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: SEED_MAX = np.iinfo(np.uint32).max -def get_random_seed(): +def get_random_seed() -> int: rng = np.random.default_rng(seed=None) return int(rng.integers(0, SEED_MAX)) -def uuid_string(): +def uuid_string() -> str: res = uuid.uuid4() return str(res) -def is_optional(value: typing.Any): +def is_optional(value: typing.Any) -> bool: """Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None].""" return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value) diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index f166206d528..f2b5643aa31 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -3,7 +3,7 @@ from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage -from ...backend.model_management.models import BaseModelType +from ...backend.model_manager import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL from ..invocations.baseinvocation import InvocationContext diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index ae9a12edbe2..9fe97ee525e 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,5 +1,3 @@ """ Initialization file for invokeai.backend """ -from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401 -from .model_management.models import SilenceWarnings # noqa: F401 diff --git a/invokeai/backend/embeddings/__init__.py b/invokeai/backend/embeddings/__init__.py new file mode 100644 index 00000000000..46ead533c4d --- /dev/null +++ b/invokeai/backend/embeddings/__init__.py @@ -0,0 +1,4 @@ +"""Initialization file for invokeai.backend.embeddings modules.""" + +# from .model_patcher import ModelPatcher +# __all__ = ["ModelPatcher"] diff --git a/invokeai/backend/embeddings/embedding_base.py b/invokeai/backend/embeddings/embedding_base.py new file mode 100644 index 00000000000..5e752a29e14 --- /dev/null +++ b/invokeai/backend/embeddings/embedding_base.py @@ -0,0 +1,12 @@ +"""Base class for LoRA and Textual Inversion models. + +The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. + +The use of "Raw" here is a historical artifact, and carried forward in +order to avoid confusion. +""" + + +class EmbeddingModelRaw: + """Base class for LoRA and Textual Inversion models.""" diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index b9649925e14..92ddef5ecc3 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -8,8 +8,8 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend import SilenceWarnings from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.silence_warnings import SilenceWarnings config = InvokeAIAppConfig.get_config() diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index e54be527d95..3623b623a94 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -25,18 +25,20 @@ ModelSource, URLModelSource, ) +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, + ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException from invokeai.backend.util.logging import InvokeAILogger # name of the starter models file -INITIAL_MODELS = "INITIAL_MODELS2.yaml" +INITIAL_MODELS = "INITIAL_MODELS.yaml" def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: @@ -44,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService logger = InvokeAILogger.get_logger(config=app_config) image_files = DiskImageFileStorage(f"{app_config.output_path}/images") db = init_db(config=app_config, logger=logger, image_files=image_files) - obj: ModelRecordServiceBase = ModelRecordServiceSQL(db) + obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) return obj @@ -53,12 +55,10 @@ def initialize_installer( ) -> ModelInstallServiceBase: """Return an initialized ModelInstallService object.""" record_store = initialize_record_store(app_config) - metadata_store = record_store.metadata_store download_queue = DownloadQueueService() installer = ModelInstallService( app_config=app_config, record_store=record_store, - metadata_store=metadata_store, download_queue=download_queue, event_bus=event_bus, ) @@ -98,11 +98,13 @@ def __init__(self) -> None: super().__init__() self._bars: Dict[str, tqdm] = {} self._last: Dict[str, int] = {} + self._logger = InvokeAILogger.get_logger(__name__) def dispatch(self, event_name: str, payload: Any) -> None: """Dispatch an event by appending it to self.events.""" + data = payload["data"] + source = data["source"] if payload["event"] == "model_install_downloading": - data = payload["data"] dest = data["local_path"] total_bytes = data["total_bytes"] bytes = data["bytes"] @@ -111,6 +113,12 @@ def dispatch(self, event_name: str, payload: Any) -> None: self._last[dest] = 0 self._bars[dest].update(bytes - self._last[dest]) self._last[dest] = bytes + elif payload["event"] == "model_install_completed": + self._logger.info(f"{source}: installed successfully.") + elif payload["event"] == "model_install_error": + self._logger.warning(f"{source}: installation failed with error {data['error']}") + elif payload["event"] == "model_install_cancelled": + self._logger.warning(f"{source}: installation cancelled") class InstallHelper(object): @@ -225,11 +233,19 @@ def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: if model_path.exists(): # local file on disk return LocalModelSource(path=model_path.absolute(), inplace=True) - if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id + + # parsing huggingface repo ids + # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" + variants = "|".join([x.lower() for x in ModelRepoVariant.__members__]) + if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): + repo_id = match.group(1) + repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None + subfolder = Path(model_info.subfolder) if model_info.subfolder else None return HFModelSource( - repo_id=model_path_id_or_url, + repo_id=repo_id, access_token=HfFolder.get_token(), - subfolder=model_info.subfolder, + subfolder=subfolder, + variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) @@ -270,12 +286,14 @@ def add_or_delete(self, selections: InstallSelections) -> None: model_name=model_name, ) if len(matches) > 1: - print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.") + self._logger.error( + "{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate" + ) elif not matches: - print(f"{model}: unknown model") + self._logger.error(f"{model_to_remove}: unknown model") else: for m in matches: - print(f"Deleting {m.type}:{m.name}") + self._logger.info(f"Deleting {m.type}:{m.name}") installer.delete(m.key) installer.wait_for_installs() diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 3cb7db6c82c..4dfa2b070c0 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -18,31 +18,30 @@ from enum import Enum from pathlib import Path from shutil import get_terminal_size -from typing import Any, get_args, get_type_hints +from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints from urllib import request import npyscreen -import omegaconf import psutil import torch import transformers -import yaml -from diffusers import AutoencoderKL +from diffusers import AutoencoderKL, ModelMixin from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from huggingface_hub import HfFolder from huggingface_hub import login as hf_hub_login -from omegaconf import OmegaConf -from pydantic import ValidationError +from omegaconf import DictConfig, OmegaConf +from pydantic.error_wrappers import ValidationError from tqdm import tqdm from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.install.legacy_arg_parsing import legacy_parser -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained -from invokeai.backend.model_management.model_probe import BaseModelType, ModelType +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.model_install import addModelsForm, process_and_execute +from invokeai.frontend.install.model_install import addModelsForm # TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( @@ -61,7 +60,7 @@ transformers.logging.set_verbosity_error() -def get_literal_fields(field) -> list[Any]: +def get_literal_fields(field: str) -> Tuple[Any]: return get_args(get_type_hints(InvokeAIAppConfig).get(field)) @@ -80,8 +79,7 @@ def get_literal_fields(field) -> list[Any]: GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"] GB = 1073741824 # GB in bytes HAS_CUDA = torch.cuda.is_available() -_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0) - +_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0) MAX_VRAM /= GB MAX_RAM = psutil.virtual_memory().total / GB @@ -96,13 +94,15 @@ def get_literal_fields(field) -> list[Any]: class DummyWidgetValue(Enum): + """Dummy widget values.""" + zero = 0 true = True false = False # -------------------------------------------- -def postscript(errors: None): +def postscript(errors: Set[str]) -> None: if not any(errors): message = f""" ** INVOKEAI INSTALLATION SUCCESSFUL ** @@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True): # --------------------------------------------- -def HfLogin(access_token) -> str: +def HfLogin(access_token) -> None: """ Helper for logging in to Huggingface The stdout capture is needed to hide the irrelevant "git credential helper" warning @@ -162,7 +162,7 @@ def HfLogin(access_token) -> str: # ------------------------------------- class ProgressBar: - def __init__(self, model_name="file"): + def __init__(self, model_name: str = "file"): self.pbar = None self.name = model_name @@ -179,6 +179,22 @@ def __call__(self, block_num, block_size, total_size): self.pbar.update(block_size) +# --------------------------------------------- +def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any): + filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731 + logger.addFilter(filter) + try: + model = model_class.from_pretrained( + model_name, + resume_download=True, + **kwargs, + ) + model.save_pretrained(destination, safe_serialization=True) + finally: + logger.removeFilter(filter) + return destination + + # --------------------------------------------- def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"): try: @@ -249,6 +265,7 @@ def download_conversion_models(): # --------------------------------------------- +# TO DO: use the download queue here. def download_realesrgan(): logger.info("Installing ESRGAN Upscaling models...") URLs = [ @@ -288,18 +305,19 @@ def download_lama(): # --------------------------------------------- -def download_support_models(): +def download_support_models() -> None: download_realesrgan() download_lama() download_conversion_models() # ------------------------------------- -def get_root(root: str = None) -> str: +def get_root(root: Optional[str] = None) -> str: if root: return root - elif os.environ.get("INVOKEAI_ROOT"): - return os.environ.get("INVOKEAI_ROOT") + elif root := os.environ.get("INVOKEAI_ROOT"): + assert root is not None + return root else: return str(config.root_path) @@ -455,6 +473,25 @@ def create(self): max_width=110, scroll_exit=True, ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 1 + self.disk = self.add_widget_intelligent( + npyscreen.Slider, + value=clip(old_opts.convert_cache, range=(0, 100), step=0.5), + out_of=100, + lowest=0.0, + step=0.5, + relx=8, + scroll_exit=True, + ) + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).", @@ -495,6 +532,14 @@ def create(self): ) else: self.vram = DummyWidgetValue.zero + + self.nextrely += 1 + self.add_widget_intelligent( + npyscreen.FixedText, + value="Location of the database used to store model path and configuration information:", + editable=False, + color="CONTROL", + ) self.nextrely += 1 self.outdir = self.add_widget_intelligent( FileBox, @@ -506,19 +551,21 @@ def create(self): labelColor="GOOD", begin_entry_at=40, max_height=3, + max_width=127, scroll_exit=True, ) self.autoimport_dirs = {} self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent( FileBox, - name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models", - value=str(config.root_path / config.autoimport_dir), + name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models", + value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "", select_dir=True, must_exist=False, use_two_lines=False, labelColor="GOOD", begin_entry_at=32, max_height=3, + max_width=127, scroll_exit=True, ) self.nextrely += 1 @@ -555,6 +602,10 @@ def show_hide_slice_sizes(self, value): self.attention_slice_label.hidden = not show self.attention_slice_size.hidden = not show + def show_hide_model_conf_override(self, value): + self.model_conf_override.hidden = value + self.model_conf_override.display() + def on_ok(self): options = self.marshall_arguments() if self.validate_field_values(options): @@ -584,18 +635,21 @@ def validate_field_values(self, opt: Namespace) -> bool: else: return True - def marshall_arguments(self): + def marshall_arguments(self) -> Namespace: new_opts = Namespace() for attr in [ "ram", "vram", + "convert_cache", "outdir", ]: if hasattr(self, attr): setattr(new_opts, attr, getattr(self, attr).value) for attr in self.autoimport_dirs: + if not self.autoimport_dirs[attr].value: + continue directory = Path(self.autoimport_dirs[attr].value) if directory.is_relative_to(config.root_path): directory = directory.relative_to(config.root_path) @@ -615,13 +669,14 @@ def marshall_arguments(self): class EditOptApplication(npyscreen.NPSAppManaged): - def __init__(self, program_opts: Namespace, invokeai_opts: Namespace): + def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper): super().__init__() self.program_opts = program_opts self.invokeai_opts = invokeai_opts self.user_cancelled = False self.autoload_pending = True - self.install_selections = default_user_selections(program_opts) + self.install_helper = install_helper + self.install_selections = default_user_selections(program_opts, install_helper) def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -640,16 +695,10 @@ def onStart(self): cycle_widgets=False, ) - def new_opts(self): + def new_opts(self) -> Namespace: return self.options.marshall_arguments() -def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace: - editApp = EditOptApplication(program_opts, invokeai_opts) - editApp.run() - return editApp.new_opts() - - def default_ramcache() -> float: """Run a heuristic for the default RAM cache based on installed RAM.""" @@ -660,27 +709,18 @@ def default_ramcache() -> float: ) # 2.1 is just large enough for sd 1.5 ;-) -def default_startup_options(init_file: Path) -> Namespace: +def default_startup_options(init_file: Path) -> InvokeAIAppConfig: opts = InvokeAIAppConfig.get_config() - opts.ram = opts.ram or default_ramcache() + opts.ram = default_ramcache() return opts -def default_user_selections(program_opts: Namespace) -> InstallSelections: - try: - installer = ModelInstall(config) - except omegaconf.errors.ConfigKeyError: - logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing") - initialize_rootdir(config.root_path, True) - installer = ModelInstall(config) - - models = installer.all_models() +def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections: + default_model = install_helper.default_model() + assert default_model is not None + default_models = [default_model] if program_opts.default_only else install_helper.recommended_models() return InstallSelections( - install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] - if program_opts.default_only - else [models[x].path or models[x].repo_id for x in installer.recommended_models()] - if program_opts.yes_to_all - else [], + install_models=default_models if program_opts.yes_to_all else [], ) @@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False): path.mkdir(parents=True, exist_ok=True) -def maybe_create_models_yaml(root: Path): - models_yaml = root / "configs" / "models.yaml" - if models_yaml.exists(): - if OmegaConf.load(models_yaml).get("__metadata__"): # up to date - return - else: - logger.info("Creating new models.yaml, original saved as models.yaml.orig") - models_yaml.rename(models_yaml.parent / "models.yaml.orig") - - with open(models_yaml, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - # ------------------------------------- -def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace): +def run_console_ui( + program_opts: Namespace, initfile: Path, install_helper: InstallHelper +) -> Tuple[Optional[Namespace], Optional[InstallSelections]]: invokeai_opts = default_startup_options(initfile) invokeai_opts.root = program_opts.root @@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - # the install-models application spawns a subprocess to install - # models, and will crash unless this is set before running. - import torch - - torch.multiprocessing.set_start_method("spawn") - - editApp = EditOptApplication(program_opts, invokeai_opts) + editApp = EditOptApplication(program_opts, invokeai_opts, install_helper) editApp.run() if editApp.user_cancelled: return (None, None) else: - return (editApp.new_opts, editApp.install_selections) + return (editApp.new_opts(), editApp.install_selections) # ------------------------------------- -def write_opts(opts: Namespace, init_file: Path): +def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None: """ Update the invokeai.yaml file with values from current settings. """ @@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path): new_config = InvokeAIAppConfig.get_config() new_config.root = config.root - for key, value in opts.__dict__.items(): + for key, value in opts.model_dump().items(): if hasattr(new_config, key): setattr(new_config, key, value) @@ -779,7 +802,7 @@ def default_output_dir() -> Path: # ------------------------------------- -def write_default_options(program_opts: Namespace, initfile: Path): +def write_default_options(program_opts: Namespace, initfile: Path) -> None: opt = default_startup_options(initfile) write_opts(opt, initfile) @@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path): # the legacy Args object in order to parse # the old init file and write out the new # yaml format. -def migrate_init_file(legacy_format: Path): +def migrate_init_file(legacy_format: Path) -> None: old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) new = InvokeAIAppConfig.get_config() - fields = [ - x - for x, y in InvokeAIAppConfig.model_fields.items() - if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED" - ] - for attr in fields: + for attr in InvokeAIAppConfig.model_fields.keys(): if hasattr(old, attr): try: setattr(new, attr, getattr(old, attr)) @@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path): # ------------------------------------- -def migrate_models(root: Path): +def migrate_models(root: Path) -> None: from invokeai.backend.install.migrate_to_3 import do_migrate do_migrate(root, root) @@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: ): logger.info("** Migrating invokeai.init to invokeai.yaml") migrate_init_file(old_init_file) - config.parse_args(argv=[], conf=OmegaConf.load(new_init_file)) + omegaconf = OmegaConf.load(new_init_file) + assert isinstance(omegaconf, DictConfig) + config.parse_args(argv=[], conf=omegaconf) if old_hub.exists(): migrate_models(config.root_path) @@ -849,7 +869,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--skip-sd-weights", @@ -908,6 +928,7 @@ def main() -> None: if opt.full_precision: invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) + config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) logger = InvokeAILogger().get_logger(config=config) errors = set() @@ -921,14 +942,18 @@ def main() -> None: # run this unconditionally in case new directories need to be added initialize_rootdir(config.root_path, opt.yes_to_all) - models_to_download = default_user_selections(opt) + # this will initialize the models.yaml file if not present + install_helper = InstallHelper(config, logger) + + models_to_download = default_user_selections(opt, install_helper) new_init_file = config.root_path / "invokeai.yaml" if opt.yes_to_all: write_default_options(opt, new_init_file) init_options = Namespace(precision="float32" if opt.full_precision else "float16") + else: - init_options, models_to_download = run_console_ui(opt, new_init_file) + init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper) if init_options: write_opts(init_options, new_init_file) else: @@ -943,10 +968,12 @@ def main() -> None: if opt.skip_sd_weights: logger.warning("Skipping diffusion weights download per user request") + elif models_to_download: - process_and_execute(opt, models_to_download) + install_helper.add_or_delete(models_to_download) postscript(errors=errors) + if not opt.yes_to_all: input("Press any key to continue...") except WindowTooSmallException as e: diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py deleted file mode 100644 index e15eb23f5b2..00000000000 --- a/invokeai/backend/install/migrate_to_3.py +++ /dev/null @@ -1,591 +0,0 @@ -""" -Migrate the models directory and models.yaml file from an existing -InvokeAI 2.3 installation to 3.0.0. -""" - -import argparse -import os -import shutil -import warnings -from dataclasses import dataclass -from pathlib import Path -from typing import Union - -import diffusers -import transformers -import yaml -from diffusers import AutoencoderKL, StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from omegaconf import DictConfig, OmegaConf -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager -from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType - -warnings.filterwarnings("ignore") -transformers.logging.set_verbosity_error() -diffusers.logging.set_verbosity_error() - - -# holder for paths that we will migrate -@dataclass -class ModelPaths: - models: Path - embeddings: Path - loras: Path - controlnets: Path - - -class MigrateTo3(object): - def __init__( - self, - from_root: Path, - to_models: Path, - model_manager: ModelManager, - src_paths: ModelPaths, - ): - self.root_directory = from_root - self.dest_models = to_models - self.mgr = model_manager - self.src_paths = src_paths - - @classmethod - def initialize_yaml(cls, yaml_file: Path): - with open(yaml_file, "w") as file: - file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - def create_directory_structure(self): - """ - Create the basic directory structure for the models folder. - """ - for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - for model_type in [ - ModelType.Main, - ModelType.Vae, - ModelType.Lora, - ModelType.ControlNet, - ModelType.TextualInversion, - ]: - path = self.dest_models / model_base.value / model_type.value - path.mkdir(parents=True, exist_ok=True) - path = self.dest_models / "core" - path.mkdir(parents=True, exist_ok=True) - - @staticmethod - def copy_file(src: Path, dest: Path): - """ - copy a single file with logging - """ - if dest.exists(): - logger.info(f"Skipping existing {str(dest)}") - return - logger.info(f"Copying {str(src)} to {str(dest)}") - try: - shutil.copy(src, dest) - except Exception as e: - logger.error(f"COPY FAILED: {str(e)}") - - @staticmethod - def copy_dir(src: Path, dest: Path): - """ - Recursively copy a directory with logging - """ - if dest.exists(): - logger.info(f"Skipping existing {str(dest)}") - return - - logger.info(f"Copying {str(src)} to {str(dest)}") - try: - shutil.copytree(src, dest) - except Exception as e: - logger.error(f"COPY FAILED: {str(e)}") - - def migrate_models(self, src_dir: Path): - """ - Recursively walk through src directory, probe anything - that looks like a model, and copy the model into the - appropriate location within the destination models directory. - """ - directories_scanned = set() - for root, dirs, files in os.walk(src_dir, followlinks=True): - for d in dirs: - try: - model = Path(root, d) - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / model.name - self.copy_dir(model, dest) - directories_scanned.add(model) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - for f in files: - # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin - # let them be copied as part of a tree copy operation - try: - if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}: - continue - model = Path(root, f) - if model.parent in directories_scanned: - continue - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / f - self.copy_file(model, dest) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - - def migrate_support_models(self): - """ - Copy the clipseg, upscaler, and restoration models to their new - locations. - """ - dest_directory = self.dest_models - if (self.root_directory / "models/clipseg").exists(): - self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg") - if (self.root_directory / "models/realesrgan").exists(): - self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan") - for d in ["codeformer", "gfpgan"]: - path = self.root_directory / "models" / d - if path.exists(): - self.copy_dir(path, dest_directory / f"core/face_restoration/{d}") - - def migrate_tuning_models(self): - """ - Migrate the embeddings, loras and controlnets directories to their new homes. - """ - for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]: - if not src: - continue - if src.is_dir(): - logger.info(f"Scanning {src}") - self.migrate_models(src) - else: - logger.info(f"{src} directory not found; skipping") - continue - - def migrate_conversion_models(self): - """ - Migrate all the models that are needed by the ckpt_to_diffusers conversion - script. - """ - - dest_directory = self.dest_models - kwargs = { - "cache_dir": self.root_directory / "models/hub", - # local_files_only = True - } - try: - logger.info("Migrating core tokenizers and text encoders") - target_dir = dest_directory / "core" / "convert" - - self._migrate_pretrained( - BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs - ) - - # sd-1 - repo_id = "openai/clip-vit-large-patch14" - self._migrate_pretrained( - CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs - ) - self._migrate_pretrained( - CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs - ) - - # sd-2 - repo_id = "stabilityai/stable-diffusion-2" - self._migrate_pretrained( - CLIPTokenizer, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-2-clip" / "tokenizer", - **{"subfolder": "tokenizer", **kwargs}, - ) - self._migrate_pretrained( - CLIPTextModel, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-2-clip" / "text_encoder", - **{"subfolder": "text_encoder", **kwargs}, - ) - - # VAE - logger.info("Migrating stable diffusion VAE") - self._migrate_pretrained( - AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs - ) - - # safety checking - logger.info("Migrating safety checker") - repo_id = "CompVis/stable-diffusion-safety-checker" - self._migrate_pretrained( - AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs - ) - self._migrate_pretrained( - StableDiffusionSafetyChecker, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-safety-checker", - **kwargs, - ) - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - - def _model_probe_to_path(self, info: ModelProbeInfo) -> Path: - return Path(self.dest_models, info.base_type.value, info.model_type.value) - - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs): - if dest.exists() and not force: - logger.info(f"Skipping existing {dest}") - return - model = model_class.from_pretrained(repo_id, **kwargs) - self._save_pretrained(model, dest, overwrite=force) - - def _save_pretrained(self, model, dest: Path, overwrite: bool = False): - model_name = dest.name - if overwrite: - model.save_pretrained(dest, safe_serialization=True) - else: - download_path = dest.with_name(f"{model_name}.downloading") - model.save_pretrained(download_path, safe_serialization=True) - download_path.replace(dest) - - def _download_vae(self, repo_id: str, subfolder: str = None) -> Path: - vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder) - info = ModelProbe().heuristic_probe(vae) - _, model_name = repo_id.split("/") - dest = self._model_probe_to_path(info) / self.unique_name(model_name, info) - vae.save_pretrained(dest, safe_serialization=True) - return dest - - def _vae_path(self, vae: Union[str, dict]) -> Path: - """ - Convert 2.3 VAE stanza to a straight path. - """ - vae_path = None - - # First get a path - if isinstance(vae, str): - vae_path = vae - - elif isinstance(vae, DictConfig): - if p := vae.get("path"): - vae_path = p - elif repo_id := vae.get("repo_id"): - if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded - vae_path = "models/core/convert/sd-vae-ft-mse" - return vae_path - else: - vae_path = self._download_vae(repo_id, vae.get("subfolder")) - - assert vae_path is not None, "Couldn't find VAE for this model" - - # if the VAE is in the old models directory, then we must move it into the new - # one. VAEs outside of this directory can stay where they are. - vae_path = Path(vae_path) - if vae_path.is_relative_to(self.src_paths.models): - info = ModelProbe().heuristic_probe(vae_path) - dest = self._model_probe_to_path(info) / vae_path.name - if not dest.exists(): - if vae_path.is_dir(): - self.copy_dir(vae_path, dest) - else: - self.copy_file(vae_path, dest) - vae_path = dest - - if vae_path.is_relative_to(self.dest_models): - rel_path = vae_path.relative_to(self.dest_models) - return Path("models", rel_path) - else: - return vae_path - - def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config): - """ - Migrate a locally-cached diffusers pipeline identified with a repo_id - """ - dest_dir = self.dest_models - - cache = self.root_directory / "models/hub" - kwargs = { - "cache_dir": cache, - "safety_checker": None, - # local_files_only = True, - } - - owner, repo_name = repo_id.split("/") - model_name = model_name or repo_name - model = cache / "--".join(["models", owner, repo_name]) - - if len(list(model.glob("snapshots/**/model_index.json"))) == 0: - return - revisions = [x.name for x in model.glob("refs/*")] - - # if an fp16 is available we use that - revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0] - pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs) - - info = ModelProbe().heuristic_probe(pipeline) - if not info: - return - - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.") - return - - dest = self._model_probe_to_path(info) / model_name - self._save_pretrained(pipeline, dest) - - rel_path = Path("models", dest.relative_to(dest_dir)) - self._add_model(model_name, info, rel_path, **extra_config) - - def migrate_path(self, location: Path, model_name: str = None, **extra_config): - """ - Migrate a model referred to using 'weights' or 'path' - """ - - # handle relative paths - dest_dir = self.dest_models - location = self.root_directory / location - model_name = model_name or location.stem - - info = ModelProbe().heuristic_probe(location) - if not info: - return - - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.") - return - - # uh oh, weights is in the old models directory - move it into the new one - if Path(location).is_relative_to(self.src_paths.models): - dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name) - if location.is_dir(): - self.copy_dir(location, dest) - else: - self.copy_file(location, dest) - location = Path("models", info.base_type.value, info.model_type.value, location.name) - - self._add_model(model_name, info, location, **extra_config) - - def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config): - if info.model_type != ModelType.Main: - return - - self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - clobber=True, - model_attributes={ - "path": str(location), - "description": f"A {info.base_type.value} {info.model_type.value} model", - "model_format": info.format, - "variant": info.variant_type.value, - **extra_config, - }, - ) - - def migrate_defined_models(self): - """ - Migrate models defined in models.yaml - """ - # find any models referred to in old models.yaml - conf = OmegaConf.load(self.root_directory / "configs/models.yaml") - - for model_name, stanza in conf.items(): - try: - passthru_args = {} - - if vae := stanza.get("vae"): - try: - passthru_args["vae"] = str(self._vae_path(vae)) - except Exception as e: - logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"') - logger.warning(str(e)) - - if config := stanza.get("config"): - passthru_args["config"] = config - - if description := stanza.get("description"): - passthru_args["description"] = description - - if repo_id := stanza.get("repo_id"): - logger.info(f"Migrating diffusers model {model_name}") - self.migrate_repo_id(repo_id, model_name, **passthru_args) - - elif location := stanza.get("weights"): - logger.info(f"Migrating checkpoint model {model_name}") - self.migrate_path(Path(location), model_name, **passthru_args) - - elif location := stanza.get("path"): - logger.info(f"Migrating diffusers model {model_name}") - self.migrate_path(Path(location), model_name, **passthru_args) - - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - - def migrate(self): - self.create_directory_structure() - # the configure script is doing this - self.migrate_support_models() - self.migrate_conversion_models() - self.migrate_tuning_models() - self.migrate_defined_models() - - -def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths: - """ - Returns tuple of (embedding_path, lora_path, controlnet_path) - """ - parser = argparse.ArgumentParser(fromfile_prefix_chars="@") - parser.add_argument( - "--embedding_directory", - "--embedding_path", - type=Path, - dest="embedding_path", - default=Path("embeddings"), - ) - parser.add_argument( - "--lora_directory", - dest="lora_path", - type=Path, - default=Path("loras"), - ) - opt, _ = parser.parse_known_args([f"@{str(initfile)}"]) - return ModelPaths( - models=root / "models", - embeddings=root / str(opt.embedding_path).strip('"'), - loras=root / str(opt.lora_path).strip('"'), - controlnets=root / "controlnets", - ) - - -def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths: - """ - Returns tuple of (embedding_path, lora_path, controlnet_path) - """ - # Don't use the config object because it is unforgiving of version updates - # Just use omegaconf directly - opt = OmegaConf.load(initfile) - paths = opt.InvokeAI.Paths - models = paths.get("models_dir", "models") - embeddings = paths.get("embedding_dir", "embeddings") - loras = paths.get("lora_dir", "loras") - controlnets = paths.get("controlnet_dir", "controlnets") - return ModelPaths( - models=root / models if models else None, - embeddings=root / embeddings if embeddings else None, - loras=root / loras if loras else None, - controlnets=root / controlnets if controlnets else None, - ) - - -def get_legacy_embeddings(root: Path) -> ModelPaths: - path = root / "invokeai.init" - if path.exists(): - return _parse_legacy_initfile(root, path) - path = root / "invokeai.yaml" - if path.exists(): - return _parse_legacy_yamlfile(root, path) - - -def do_migrate(src_directory: Path, dest_directory: Path): - """ - Migrate models from src to dest InvokeAI root directories - """ - config_file = dest_directory / "configs" / "models.yaml.3" - dest_models = dest_directory / "models.3" - - version_3 = (dest_directory / "models" / "core").exists() - - # Here we create the destination models.yaml file. - # If we are writing into a version 3 directory and the - # file already exists, then we write into a copy of it to - # avoid deleting its previous customizations. Otherwise we - # create a new empty one. - if version_3: # write into the dest directory - try: - shutil.copy(dest_directory / "configs" / "models.yaml", config_file) - except Exception: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory - (dest_directory / "models").replace(dest_models) - else: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) - - paths = get_legacy_embeddings(src_directory) - migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths) - migrator.migrate() - print("Migration successful.") - - if not version_3: - (dest_directory / "models").replace(src_directory / "models.orig") - print(f"Original models directory moved to {dest_directory}/models.orig") - - (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig") - print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig") - - config_file.replace(config_file.with_suffix("")) - dest_models.replace(dest_models.with_suffix("")) - - -def main(): - parser = argparse.ArgumentParser( - prog="invokeai-migrate3", - description=""" -This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format -'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a - -The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively. -It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure -script, which will perform a full upgrade in place.""", - ) - parser.add_argument( - "--from-directory", - dest="src_root", - type=Path, - required=True, - help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")', - ) - parser.add_argument( - "--to-directory", - dest="dest_root", - type=Path, - required=True, - help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")', - ) - args = parser.parse_args() - src_root = args.src_root - assert src_root.is_dir(), f"{src_root} is not a valid directory" - assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory" - assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory" - assert (src_root / "invokeai.init").exists() or ( - src_root / "invokeai.yaml" - ).exists(), f"{src_root} does not contain an InvokeAI init file." - - dest_root = args.dest_root - assert dest_root.is_dir(), f"{dest_root} is not a valid directory" - config = InvokeAIAppConfig.get_config() - config.parse_args(["--root", str(dest_root)]) - - # TODO: revisit - don't rely on invokeai.yaml to exist yet! - dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists() - if not dest_is_setup: - from invokeai.backend.install.invokeai_configure import initialize_rootdir - - initialize_rootdir(dest_root, True) - - do_migrate(src_root, dest_root) - - -if __name__ == "__main__": - main() diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py deleted file mode 100644 index fdbe714f62c..00000000000 --- a/invokeai/backend/install/model_install_backend.py +++ /dev/null @@ -1,637 +0,0 @@ -""" -Utility (backend) functions used by model_install.py -""" -import os -import re -import shutil -import warnings -from dataclasses import dataclass, field -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Callable, Dict, List, Optional, Set, Union - -import requests -import torch -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging -from huggingface_hub import HfApi, HfFolder, hf_hub_url -from omegaconf import OmegaConf -from tqdm import tqdm - -import invokeai.configs as configs -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType -from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType -from invokeai.backend.util import download_with_resume -from invokeai.backend.util.devices import choose_torch_device, torch_dtype - -from ..util.logging import InvokeAILogger - -warnings.filterwarnings("ignore") - -# --------------------------globals----------------------- -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger(name="InvokeAI") - -# the initial "configs" dir is now bundled in the `invokeai.configs` package -Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" - -Config_preamble = """ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. -""" - -LEGACY_CONFIGS = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -@dataclass -class InstallSelections: - install_models: List[str] = field(default_factory=list) - remove_models: List[str] = field(default_factory=list) - - -@dataclass -class ModelLoadInfo: - name: str - model_type: ModelType - base_type: BaseModelType - path: Optional[Path] = None - repo_id: Optional[str] = None - subfolder: Optional[str] = None - description: str = "" - installed: bool = False - recommended: bool = False - default: bool = False - requires: Optional[List[str]] = field(default_factory=list) - - -class ModelInstall(object): - def __init__( - self, - config: InvokeAIAppConfig, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - model_manager: Optional[ModelManager] = None, - access_token: Optional[str] = None, - civitai_api_key: Optional[str] = None, - ): - self.config = config - self.mgr = model_manager or ModelManager(config.model_conf_path) - self.datasets = OmegaConf.load(Dataset_path) - self.prediction_helper = prediction_type_helper - self.access_token = access_token or HfFolder.get_token() - self.civitai_api_key = civitai_api_key or config.civitai_api_key - self.reverse_paths = self._reverse_paths(self.datasets) - - def all_models(self) -> Dict[str, ModelLoadInfo]: - """ - Return dict of model_key=>ModelLoadInfo objects. - This method consolidates and simplifies the entries in both - models.yaml and INITIAL_MODELS.yaml so that they can - be treated uniformly. It also sorts the models alphabetically - by their name, to improve the display somewhat. - """ - model_dict = {} - - # first populate with the entries in INITIAL_MODELS.yaml - for key, value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - value["name"] = name - value["base_type"] = base - value["model_type"] = model_type - model_info = ModelLoadInfo(**value) - if model_info.subfolder and model_info.repo_id: - model_info.repo_id += f":{model_info.subfolder}" - model_dict[key] = model_info - - # supplement with entries in models.yaml - installed_models = list(self.mgr.list_models()) - - for md in installed_models: - base = md["base_model"] - model_type = md["model_type"] - name = md["model_name"] - key = ModelManager.create_key(name, base, model_type) - if key in model_dict: - model_dict[key].installed = True - else: - model_dict[key] = ModelLoadInfo( - name=name, - base_type=base, - model_type=model_type, - path=value.get("path"), - installed=True, - ) - return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} - - def _is_autoloaded(self, model_info: dict) -> bool: - path = model_info.get("path") - if not path: - return False - for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]: - if autodir_path := getattr(self.config, autodir): - autodir_path = self.config.root_path / autodir_path - if Path(path).is_relative_to(autodir_path): - return True - return False - - def list_models(self, model_type): - installed = self.mgr.list_models(model_type=model_type) - print() - print(f"Installed models of type `{model_type}`:") - print(f"{'Model Key':50} Model Path") - for i in installed: - print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}") - print() - - # logic here a little reversed to maintain backward compatibility - def starter_models(self, all_models: bool = False) -> Set[str]: - models = set() - for key, _value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - if all_models or model_type in [ModelType.Main, ModelType.Vae]: - models.add(key) - return models - - def recommended_models(self) -> Set[str]: - starters = self.starter_models(all_models=True) - return {x for x in starters if self.datasets[x].get("recommended", False)} - - def default_model(self) -> str: - starters = self.starter_models() - defaults = [x for x in starters if self.datasets[x].get("default", False)] - return defaults[0] - - def install(self, selections: InstallSelections): - verbosity = dlogging.get_verbosity() # quench NSFW nags - dlogging.set_verbosity_error() - - job = 1 - jobs = len(selections.remove_models) + len(selections.install_models) - - # remove requested models - for key in selections.remove_models: - name, base, mtype = self.mgr.parse_key(key) - logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]") - try: - self.mgr.del_model(name, base, mtype) - except FileNotFoundError as e: - logger.warning(e) - job += 1 - - # add requested models - self._remove_installed(selections.install_models) - self._add_required_models(selections.install_models) - for path in selections.install_models: - logger.info(f"Installing {path} [{job}/{jobs}]") - try: - self.heuristic_import(path) - except (ValueError, KeyError) as e: - logger.error(str(e)) - job += 1 - - dlogging.set_verbosity(verbosity) - self.mgr.commit() - - def heuristic_import( - self, - model_path_id_or_url: Union[str, Path], - models_installed: Set[Path] = None, - ) -> Dict[str, AddModelResult]: - """ - :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL - :param models_installed: Set of installed models, used for recursive invocation - Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. - """ - - if not models_installed: - models_installed = {} - - model_path_id_or_url = str(model_path_id_or_url).strip("\"' ") - - # A little hack to allow nested routines to retrieve info on the requested ID - self.current_id = model_path_id_or_url - path = Path(model_path_id_or_url) - - # fix relative paths - if path.exists() and not path.is_absolute(): - path = path.absolute() # make relative to current WD - - # checkpoint file, or similar - if path.is_file(): - models_installed.update({str(path): self._install_path(path)}) - - # folders style or similar - elif path.is_dir() and any( - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "pytorch_lora_weights.safetensors", - } - ): - models_installed.update({str(model_path_id_or_url): self._install_path(path)}) - - # recursive scan - elif path.is_dir(): - for child in path.iterdir(): - self.heuristic_import(child, models_installed=models_installed) - - # huggingface repo - elif len(str(model_path_id_or_url).split("/")) == 2: - models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) - - # a URL - elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): - models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) - - else: - raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping") - - return models_installed - - def _remove_installed(self, model_list: List[str]): - all_models = self.all_models() - models_to_remove = [] - - for path in model_list: - key = self.reverse_paths.get(path) - if key and all_models[key].installed: - models_to_remove.append(path) - - for path in models_to_remove: - logger.warning(f"{path} already installed. Skipping") - model_list.remove(path) - - def _add_required_models(self, model_list: List[str]): - additional_models = [] - all_models = self.all_models() - for path in model_list: - if not (key := self.reverse_paths.get(path)): - continue - for requirement in all_models[key].requires: - requirement_key = self.reverse_paths.get(requirement) - if not all_models[requirement_key].installed: - additional_models.append(requirement) - model_list.extend(additional_models) - - # install a model from a local path. The optional info parameter is there to prevent - # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: - info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) - if not info: - logger.warning(f"Unable to parse format of {path}") - return None - model_name = path.stem if path.is_file() else path.name - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path, info) - return self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - model_attributes=attributes, - ) - - def _install_url(self, url: str) -> AddModelResult: - with TemporaryDirectory(dir=self.config.models_path) as staging: - CIVITAI_RE = r".*civitai.com.*" - civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE) - location = download_with_resume( - url, Path(staging), access_token=self.civitai_api_key if civit_url else None - ) - if not location: - logger.error(f"Unable to download {url}. Skipping.") - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name - dest.parent.mkdir(parents=True, exist_ok=True) - models_path = shutil.move(location, dest) - - # staged version will be garbage-collected at this time - return self._install_path(Path(models_path), info) - - def _install_repo(self, repo_id: str) -> AddModelResult: - # hack to recover models stored in subfolders -- - # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster - subfolder = None - if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id): - repo_id = match.group(1) - subfolder = match.group(2) - - hinfo = HfApi().model_info(repo_id) - - # we try to figure out how to download this most economically - # list all the files in the repo - files = [x.rfilename for x in hinfo.siblings] - if subfolder: - files = [x for x in files if x.startswith(f"{subfolder}/")] - prefix = f"{subfolder}/" if subfolder else "" - - location = None - - with TemporaryDirectory(dir=self.config.models_path) as staging: - staging = Path(staging) - if f"{prefix}model_index.json" in files: - location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline - elif f"{prefix}unet/model.onnx" in files: - location = self._download_hf_model(repo_id, files, staging) - else: - for suffix in ["safetensors", "bin"]: - if f"{prefix}pytorch_lora_weights.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder - ) # LoRA - break - elif ( - self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files - ): # vae, controlnet or some other standalone - files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}diffusion_pytorch_model.{suffix}" in files: - files = ["config.json", f"diffusion_pytorch_model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}learned_embeds.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder - ) - break - elif ( - f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files - ): # IP-Adapter - files = ["image_encoder.txt", f"ip_adapter.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files: - # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted - # by InvokeAI for use with IP-Adapters. - files = ["config.json", f"model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - if not location: - logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") - return {} - - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - if not info: - logger.warning(f"Could not probe {location}. Skipping install.") - return {} - dest = ( - self.config.models_path - / info.base_type.value - / info.model_type.value - / self._get_model_name(repo_id, location) - ) - if dest.exists(): - shutil.rmtree(dest) - shutil.copytree(location, dest) - return self._install_path(dest, info) - - def _get_model_name(self, path_name: str, location: Path) -> str: - """ - Calculate a name for the model - primitive implementation. - """ - if key := self.reverse_paths.get(path_name): - (name, base, mtype) = ModelManager.parse_key(key) - return name - elif location.is_dir(): - return location.name - else: - return location.stem - - def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: - model_name = path.name if path.is_dir() else path.stem - description = f"{info.base_type.value} {info.model_type.value} model {model_name}" - if key := self.reverse_paths.get(self.current_id): - if key in self.datasets: - description = self.datasets[key].get("description") or description - - rel_path = self.relative_to_root(path, self.config.models_path) - - attributes = { - "path": str(rel_path), - "description": str(description), - "model_format": info.format, - } - legacy_conf = None - if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: - attributes.update( - { - "variant": info.variant_type, - } - ) - if info.format == "checkpoint": - try: - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - legacy_conf = Path( - self.config.legacy_conf_dir, - LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], - ) - else: - legacy_conf = Path( - self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type] - ) - except KeyError: - legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess - - if info.model_type == ModelType.ControlNet and info.format == "checkpoint": - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - else: - legacy_conf = Path( - self.config.root_path, - "configs/controlnet", - ("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"), - ) - - if legacy_conf: - attributes.update({"config": str(legacy_conf)}) - return attributes - - def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: - root = root or self.config.root_path - if path.is_relative_to(root): - return path.relative_to(root) - else: - return path - - def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path: - """ - Retrieve a StableDiffusion model from cache or remote and then - does a save_pretrained() to the indicated staging area. - """ - _, name = repo_id.split("/") - precision = torch_dtype(choose_torch_device()) - variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] - - model = None - for variant in variants: - try: - model = DiffusionPipeline.from_pretrained( - repo_id, - variant=variant, - torch_dtype=precision, - safety_checker=None, - subfolder=subfolder, - ) - except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors - if "fp16" not in str(e): - print(e) - - if model: - break - - if not model: - logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") - return None - model.save_pretrained(staging / name, safe_serialization=True) - return staging / name - - def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: - _, name = repo_id.split("/") - location = staging / name - paths = [] - for filename in files: - filePath = Path(filename) - p = hf_download_with_resume( - repo_id, - model_dir=location / filePath.parent, - model_name=filePath.name, - access_token=self.access_token, - subfolder=filePath.parent / subfolder if subfolder else filePath.parent, - ) - if p: - paths.append(p) - else: - logger.warning(f"Could not download {filename} from {repo_id}.") - - return location if len(paths) > 0 else None - - @classmethod - def _reverse_paths(cls, datasets) -> dict: - """ - Reverse mapping from repo_id/path to destination name. - """ - return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()} - - -# ------------------------------------- -def yes_or_no(prompt: str, default_yes=True): - default = "y" if default_yes else "n" - response = input(f"{prompt} [{default}] ") or default - if default_yes: - return response[0] not in ("n", "N") - else: - return response[0] in ("y", "Y") - - -# --------------------------------------------- -def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): - logger = InvokeAILogger.get_logger("InvokeAI") - logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) - - model = model_class.from_pretrained( - model_name, - resume_download=True, - **kwargs, - ) - model.save_pretrained(destination, safe_serialization=True) - return destination - - -# --------------------------------------------- -def hf_download_with_resume( - repo_id: str, - model_dir: str, - model_name: str, - model_dest: Path = None, - access_token: str = None, - subfolder: str = None, -) -> Path: - model_dest = model_dest or Path(os.path.join(model_dir, model_name)) - os.makedirs(model_dir, exist_ok=True) - - url = hf_hub_url(repo_id, model_name, subfolder=subfolder) - - header = {"Authorization": f"Bearer {access_token}"} if access_token else {} - open_mode = "wb" - exist_size = 0 - - if os.path.exists(model_dest): - exist_size = os.path.getsize(model_dest) - header["Range"] = f"bytes={exist_size}-" - open_mode = "ab" - - resp = requests.get(url, headers=header, stream=True) - total = int(resp.headers.get("content-length", 0)) - - if resp.status_code == 416: # "range not satisfiable", which means nothing to return - logger.info(f"{model_name}: complete file found. Skipping.") - return model_dest - elif resp.status_code == 404: - logger.warning("File not found") - return None - elif resp.status_code != 200: - logger.warning(f"{model_name}: {resp.reason}") - elif exist_size > 0: - logger.info(f"{model_name}: partial file found. Resuming...") - else: - logger.info(f"{model_name}: Downloading...") - - try: - with ( - open(model_dest, open_mode) as file, - tqdm( - desc=model_name, - initial=exist_size, - total=total + exist_size, - unit="iB", - unit_scale=True, - unit_divisor=1000, - ) as bar, - ): - for data in resp.iter_content(chunk_size=1024): - size = file.write(data) - bar.update(size) - except Exception as e: - logger.error(f"An error occurred while downloading {model_name}: {str(e)}") - return None - return model_dest diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 9176bf1f49f..e51966c779c 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -8,8 +8,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from invokeai.backend.model_management.models.base import calc_model_size_by_data +from ..raw_model import RawModel from .resampler import Resampler @@ -92,7 +92,7 @@ def forward(self, image_embeds): return clip_extra_context_tokens -class IPAdapter: +class IPAdapter(RawModel): """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" def __init__( @@ -124,6 +124,9 @@ def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): self.attn_weights.to(device=self.device, dtype=self.dtype) def calc_size(self): + # workaround for circular import + from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data + return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) def _init_image_proj_model(self, state_dict): diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/lora.py similarity index 81% rename from invokeai/backend/model_management/models/lora.py rename to invokeai/backend/lora.py index b110d75d220..0b7128034a2 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/lora.py @@ -1,98 +1,17 @@ +# Copyright (c) 2024 The InvokeAI Development team +"""LoRA model support.""" + import bisect -import os -from enum import Enum from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from safetensors.torch import load_file +from typing_extensions import Self -from .base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - classproperty, -) - - -class LoRAModelFormat(str, Enum): - LyCORIS = "lycoris" - Diffusers = "diffusers" - - -class LoRAModel(ModelBase): - # model_size: int - - class Config(ModelConfigBase): - model_format: LoRAModelFormat # TODO: - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Lora - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in lora") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in lora") - - model = LoRAModelRaw.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - base_model=self.base_model, - ) - - self.model_size = model.calc_size() - return model - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - for ext in ["safetensors", "bin"]: - if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")): - return LoRAModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return LoRAModelFormat.LyCORIS - - raise InvalidModelException(f"Not a valid model: {path}") +from invokeai.backend.model_manager import BaseModelType - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: - for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder - path = Path(model_path, f"pytorch_lora_weights.{ext}") - if path.exists(): - return path - else: - return model_path +from .raw_model import RawModel class LoRALayerBase: @@ -108,7 +27,7 @@ class LoRALayerBase: def __init__( self, layer_key: str, - values: dict, + values: Dict[str, torch.Tensor], ): if "alpha" in values: self.alpha = values["alpha"].item() @@ -116,7 +35,7 @@ def __init__( self.alpha = None if "bias_indices" in values and "bias_values" in values and "bias_size" in values: - self.bias = torch.sparse_coo_tensor( + self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor( values["bias_indices"], values["bias_values"], tuple(values["bias_size"]), @@ -128,7 +47,7 @@ def __init__( self.rank = None # set in layer implementation self.layer_key = layer_key - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: raise NotImplementedError() def calc_size(self) -> int: @@ -142,7 +61,7 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) @@ -156,20 +75,20 @@ class LoRALayer(LoRALayerBase): def __init__( self, layer_key: str, - values: dict, + values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) self.up = values["lora_up.weight"] self.down = values["lora_down.weight"] if "lora_mid.weight" in values: - self.mid = values["lora_mid.weight"] + self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] else: self.mid = None self.rank = self.down.shape[0] - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) @@ -190,7 +109,7 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) self.up = self.up.to(device=device, dtype=dtype) @@ -208,11 +127,7 @@ class LoHALayer(LoRALayerBase): # t1: Optional[torch.Tensor] = None # t2: Optional[torch.Tensor] = None - def __init__( - self, - layer_key: str, - values: dict, - ): + def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]): super().__init__(layer_key, values) self.w1_a = values["hada_w1_a"] @@ -221,20 +136,20 @@ def __init__( self.w2_b = values["hada_w2_b"] if "hada_t1" in values: - self.t1 = values["hada_t1"] + self.t1: Optional[torch.Tensor] = values["hada_t1"] else: self.t1 = None if "hada_t2" in values: - self.t2 = values["hada_t2"] + self.t2: Optional[torch.Tensor] = values["hada_t2"] else: self.t2 = None self.rank = self.w1_b.shape[0] - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.t1 is None: - weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) else: rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) @@ -254,7 +169,7 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype) @@ -280,12 +195,12 @@ class LoKRLayer(LoRALayerBase): def __init__( self, layer_key: str, - values: dict, + values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) if "lokr_w1" in values: - self.w1 = values["lokr_w1"] + self.w1: Optional[torch.Tensor] = values["lokr_w1"] self.w1_a = None self.w1_b = None else: @@ -294,7 +209,7 @@ def __init__( self.w1_b = values["lokr_w1_b"] if "lokr_w2" in values: - self.w2 = values["lokr_w2"] + self.w2: Optional[torch.Tensor] = values["lokr_w2"] self.w2_a = None self.w2_b = None else: @@ -303,7 +218,7 @@ def __init__( self.w2_b = values["lokr_w2_b"] if "lokr_t2" in values: - self.t2 = values["lokr_t2"] + self.t2: Optional[torch.Tensor] = values["lokr_t2"] else: self.t2 = None @@ -314,14 +229,18 @@ def __init__( else: self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor): - w1 = self.w1 + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + w1: Optional[torch.Tensor] = self.w1 if w1 is None: + assert self.w1_a is not None + assert self.w1_b is not None w1 = self.w1_a @ self.w1_b w2 = self.w2 if w2 is None: if self.t2 is None: + assert self.w2_a is not None + assert self.w2_b is not None w2 = self.w2_a @ self.w2_b else: w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) @@ -329,6 +248,8 @@ def get_weight(self, orig_weight: torch.Tensor): if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) w2 = w2.contiguous() + assert w1 is not None + assert w2 is not None weight = torch.kron(w1, w2) return weight @@ -344,18 +265,22 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) if self.w1 is not None: self.w1 = self.w1.to(device=device, dtype=dtype) else: + assert self.w1_a is not None + assert self.w1_b is not None self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.w2 is not None: self.w2 = self.w2.to(device=device, dtype=dtype) else: + assert self.w2_a is not None + assert self.w2_b is not None self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype) @@ -369,7 +294,7 @@ class FullLayer(LoRALayerBase): def __init__( self, layer_key: str, - values: dict, + values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) @@ -382,7 +307,7 @@ def __init__( self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: return self.weight def calc_size(self) -> int: @@ -394,7 +319,7 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype) @@ -407,7 +332,7 @@ class IA3Layer(LoRALayerBase): def __init__( self, layer_key: str, - values: dict, + values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) @@ -416,10 +341,11 @@ def __init__( self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) + assert orig_weight is not None return orig_weight * weight def calc_size(self) -> int: @@ -439,28 +365,30 @@ def to( self.on_input = self.on_input.to(device=device, dtype=dtype) -# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw: # (torch.nn.Module): +AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] + + +class LoRAModelRaw(RawModel): # (torch.nn.Module): _name: str - layers: Dict[str, LoRALayer] + layers: Dict[str, AnyLoRALayer] def __init__( self, name: str, - layers: Dict[str, LoRALayer], + layers: Dict[str, AnyLoRALayer], ): self._name = name self.layers = layers @property - def name(self): + def name(self) -> str: return self._name def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: # TODO: try revert if exception? for _key, layer in self.layers.items(): layer.to(device=device, dtype=dtype) @@ -472,7 +400,7 @@ def calc_size(self) -> int: return model_size @classmethod - def _convert_sdxl_keys_to_diffusers_format(cls, state_dict): + def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Convert the keys of an SDXL LoRA state_dict to diffusers format. The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in @@ -536,7 +464,7 @@ def from_checkpoint( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, base_model: Optional[BaseModelType] = None, - ): + ) -> Self: device = device or torch.device("cpu") dtype = dtype or torch.float32 @@ -544,16 +472,16 @@ def from_checkpoint( file_path = Path(file_path) model = cls( - name=file_path.stem, # TODO: + name=file_path.stem, layers={}, ) if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + sd = load_file(file_path.absolute().as_posix(), device="cpu") else: - state_dict = torch.load(file_path, map_location="cpu") + sd = torch.load(file_path, map_location="cpu") - state_dict = cls._group_state(state_dict) + state_dict = cls._group_state(sd) if base_model == BaseModelType.StableDiffusionXL: state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) @@ -561,7 +489,7 @@ def from_checkpoint( for layer_key, values in state_dict.items(): # lora and locon if "lora_down.weight" in values: - layer = LoRALayer(layer_key, values) + layer: AnyLoRALayer = LoRALayer(layer_key, values) # loha elif "hada_w1_b" in values: @@ -592,8 +520,8 @@ def from_checkpoint( return model @staticmethod - def _group_state(state_dict: dict): - state_dict_groupped = {} + def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: + state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {} for key, value in state_dict.items(): stem, leaf = key.split(".", 1) @@ -606,7 +534,7 @@ def _group_state(state_dict: dict): # code from # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 -def make_sdxl_unet_conversion_map(): +def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]: """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" unet_conversion_map_layer = [] diff --git a/invokeai/backend/model_management/README.md b/invokeai/backend/model_management/README.md deleted file mode 100644 index 0d94f39642e..00000000000 --- a/invokeai/backend/model_management/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Model Cache - -## `glibc` Memory Allocator Fragmentation - -Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM). - -This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`. - -To better understand how the `glibc` memory allocator works, see these references: -- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html -- Details: https://sourceware.org/glibc/wiki/MallocInternals - -Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation. - -We can work around this memory fragmentation issue by setting the following env var: - -```bash -# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed. -MALLOC_MMAP_THRESHOLD_=1048576 -``` - -See the following references for more information about the `malloc` tunable parameters: -- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html -- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html -- https://man7.org/linux/man-pages/man3/mallopt.3.html - -The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected. diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py deleted file mode 100644 index 03abf58eb46..00000000000 --- a/invokeai/backend/model_management/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# ruff: noqa: I001, F401 -""" -Initialization file for invokeai.backend.model_management -""" -# This import must be first -from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType -from .lora import ModelPatcher, ONNXModelPatcher -from .model_cache import ModelCache - -from .models import ( - BaseModelType, - DuplicateModelException, - ModelNotFoundException, - ModelType, - ModelVariantType, - SubModelType, -) - -# This import must be last -from .model_merge import MergeInterpolationMethod, ModelMerger diff --git a/invokeai/backend/model_management/detect_baked_in_vae.py b/invokeai/backend/model_management/detect_baked_in_vae.py deleted file mode 100644 index 9118438548d..00000000000 --- a/invokeai/backend/model_management/detect_baked_in_vae.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team -""" -This module exports the function has_baked_in_sdxl_vae(). -It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE, -which doesn't work properly in fp16 mode. -""" - -import hashlib -from pathlib import Path - -from safetensors.torch import load_file - -SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51" - - -def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool: - """Return true if the checkpoint contains a custom (non SDXL-1.0) VAE.""" - hash = _vae_hash(checkpoint_path) - return hash != SDXL_1_0_VAE_HASH - - -def _vae_hash(checkpoint_path: Path) -> str: - checkpoint = load_file(checkpoint_path, device="cpu") - vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")] - hash = hashlib.new("sha256") - for key in vae_keys: - value = checkpoint[key] - hash.update(bytes(key, "UTF-8")) - hash.update(bytes(str(value), "UTF-8")) - - return hash.hexdigest() diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py deleted file mode 100644 index 2a7f4b5a95e..00000000000 --- a/invokeai/backend/model_management/model_cache.py +++ /dev/null @@ -1,553 +0,0 @@ -""" -Manage a RAM cache of diffusion/transformer models for fast switching. -They are moved between GPU VRAM and CPU RAM as necessary. If the cache -grows larger than a preset maximum, then the least recently used -model will be cleared and (re)loaded from disk when next needed. - -The cache returns context manager generators designed to load the -model into the GPU within the context, and unload outside the -context. Use like this: - - cache = ModelCache(max_cache_size=7.5) - with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, - cache.get_model('stabilityai/stable-diffusion-2') as SD2: - do_something_in_GPU(SD1,SD2) - - -""" - -import gc -import hashlib -import math -import os -import sys -import time -from contextlib import suppress -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, Optional, Type, Union, types - -import torch - -import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff -from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init - -from ..util.devices import choose_torch_device -from .models import BaseModelType, ModelBase, ModelType, SubModelType - -if choose_torch_device() == torch.device("mps"): - from torch import mps - -# Maximum size of the cache, in gigs -# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously -DEFAULT_MAX_CACHE_SIZE = 6.0 - -# amount of GPU memory to hold in reserve for use by generations (GB) -DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 - -# actual size of a gig -GIG = 1073741824 -# Size of a MB in bytes. -MB = 2**20 - - -@dataclass -class CacheStats(object): - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space - cache_size: int = 0 # total size of cache - # {submodel_key => size} - loaded_model_sizes: Dict[str, int] = field(default_factory=dict) - - -class ModelLocker(object): - "Forward declaration" - - pass - - -class ModelCache(object): - "Forward declaration" - - pass - - -class _CacheRecord: - size: int - model: Any - cache: ModelCache - _locks: int - - def __init__(self, cache, model: Any, size: int): - self.size = size - self.model = model - self.cache = cache - self._locks = 0 - - def lock(self): - self._locks += 1 - - def unlock(self): - self._locks -= 1 - assert self._locks >= 0 - - @property - def locked(self): - return self._locks > 0 - - @property - def loaded(self): - if self.model is not None and hasattr(self.model, "device"): - return self.model.device != self.cache.storage_device - else: - return False - - -class ModelCache(object): - def __init__( - self, - max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, - max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, - execution_device: torch.device = torch.device("cuda"), - storage_device: torch.device = torch.device("cpu"), - precision: torch.dtype = torch.float16, - sequential_offload: bool = False, - lazy_offloading: bool = True, - sha_chunksize: int = 16777216, - logger: types.ModuleType = logger, - log_memory_usage: bool = False, - ): - """ - :param max_cache_size: Maximum size of the RAM cache [6.0 GB] - :param execution_device: Torch device to load active model into [torch.device('cuda')] - :param storage_device: Torch device to save inactive model in [torch.device('cpu')] - :param precision: Precision for loaded models [torch.float16] - :param lazy_offloading: Keep model in VRAM until another model needs to be loaded - :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially - :param sha_chunksize: Chunksize to use when calculating sha256 model hash - :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache - operation, and the result will be logged (at debug level). There is a time cost to capturing the memory - snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's - behaviour. - """ - self.model_infos: Dict[str, ModelBase] = {} - # allow lazy offloading only when vram cache enabled - self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 - self.precision: torch.dtype = precision - self.max_cache_size: float = max_cache_size - self.max_vram_cache_size: float = max_vram_cache_size - self.execution_device: torch.device = execution_device - self.storage_device: torch.device = storage_device - self.sha_chunksize = sha_chunksize - self.logger = logger - self._log_memory_usage = log_memory_usage - - # used for stats collection - self.stats = None - - self._cached_models = {} - self._cache_stack = [] - - def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: - if self._log_memory_usage: - return MemorySnapshot.capture() - return None - - def get_key( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ): - key = f"{model_path}:{base_model}:{model_type}" - if submodel_type: - key += f":{submodel_type}" - return key - - def _get_model_info( - self, - model_path: str, - model_class: Type[ModelBase], - base_model: BaseModelType, - model_type: ModelType, - ): - model_info_key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=None, - ) - - if model_info_key not in self.model_infos: - self.model_infos[model_info_key] = model_class( - model_path, - base_model, - model_type, - ) - - return self.model_infos[model_info_key] - - # TODO: args - def get_model( - self, - model_path: Union[str, Path], - model_class: Type[ModelBase], - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - gpu_load: bool = True, - ) -> Any: - if not isinstance(model_path, Path): - model_path = Path(model_path) - - if not os.path.exists(model_path): - raise Exception(f"Model not found: {model_path}") - - model_info = self._get_model_info( - model_path=model_path, - model_class=model_class, - base_model=base_model, - model_type=model_type, - ) - key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=submodel, - ) - # TODO: lock for no copies on simultaneous calls? - cache_entry = self._cached_models.get(key, None) - if cache_entry is None: - self.logger.info( - f"Loading model {model_path}, type" - f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}" - ) - if self.stats: - self.stats.misses += 1 - - self_reported_model_size_before_load = model_info.get_size(submodel) - # Remove old models from the cache to make room for the new model. - self._make_cache_room(self_reported_model_size_before_load) - - # Load the model from disk and capture a memory snapshot before/after. - start_load_time = time.time() - snapshot_before = self._capture_memory_snapshot() - with skip_torch_weight_init(): - model = model_info.get_model(child_type=submodel, torch_dtype=self.precision) - snapshot_after = self._capture_memory_snapshot() - end_load_time = time.time() - - self_reported_model_size_after_load = model_info.get_size(submodel) - - self.logger.debug( - f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n" - f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /" - f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB: - self.logger.debug( - f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:" - f" {(self_reported_model_size_before_load/GIG):.2f}GB /" - f" {(self_reported_model_size_after_load/GIG):.2f}GB." - ) - - cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load) - self._cached_models[key] = cache_entry - else: - if self.stats: - self.stats.hits += 1 - - if self.stats: - self.stats.cache_size = self.max_cache_size * GIG - self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size()) - self.stats.in_cache = len(self._cached_models) - self.stats.loaded_model_sizes[key] = max( - self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel) - ) - - with suppress(Exception): - self._cache_stack.remove(key) - self._cache_stack.append(key) - - return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size) - - def _move_model_to_device(self, key: str, target_device: torch.device): - cache_entry = self._cached_models[key] - - source_device = cache_entry.model.device - # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support - # multi-GPU. - if torch.device(source_device).type == torch.device(target_device).type: - return - - start_model_to_time = time.time() - snapshot_before = self._capture_memory_snapshot() - cache_entry.model.to(target_device) - snapshot_after = self._capture_memory_snapshot() - end_model_to_time = time.time() - self.logger.debug( - f"Moved model '{key}' from {source_device} to" - f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" - f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if ( - snapshot_before is not None - and snapshot_after is not None - and snapshot_before.vram is not None - and snapshot_after.vram is not None - ): - vram_change = abs(snapshot_before.vram - snapshot_after.vram) - - # If the estimated model size does not match the change in VRAM, log a warning. - if not math.isclose( - vram_change, - cache_entry.size, - rel_tol=0.1, - abs_tol=10 * MB, - ): - self.logger.debug( - f"Moving model '{key}' from {source_device} to" - f" {target_device} caused an unexpected change in VRAM usage. The model's" - " estimated size may be incorrect. Estimated model size:" - f" {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - class ModelLocker(object): - def __init__(self, cache, key, model, gpu_load, size_needed): - """ - :param cache: The model_cache object - :param key: The key of the model to lock in GPU - :param model: The model to lock - :param gpu_load: True if load into gpu - :param size_needed: Size of the model to load - """ - self.gpu_load = gpu_load - self.cache = cache - self.key = key - self.model = model - self.size_needed = size_needed - self.cache_entry = self.cache._cached_models[self.key] - - def __enter__(self) -> Any: - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this - # code to move it into GPU! - if self.gpu_load: - self.cache_entry.lock() - - try: - if self.cache.lazy_offloading: - self.cache._offload_unlocked_models(self.size_needed) - - self.cache._move_model_to_device(self.key, self.cache.execution_device) - - self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}") - self.cache._print_cuda_stats() - - except Exception: - self.cache_entry.unlock() - raise - - # TODO: not fully understand - # in the event that the caller wants the model in RAM, we - # move it into CPU if it is in GPU and not locked - elif self.cache_entry.loaded and not self.cache_entry.locked: - self.cache._move_model_to_device(self.key, self.cache.storage_device) - - return self.model - - def __exit__(self, type, value, traceback): - if not hasattr(self.model, "to"): - return - - self.cache_entry.unlock() - if not self.cache.lazy_offloading: - self.cache._offload_unlocked_models() - self.cache._print_cuda_stats() - - # TODO: should it be called untrack_model? - def uncache_model(self, cache_id: str): - with suppress(ValueError): - self._cache_stack.remove(cache_id) - self._cached_models.pop(cache_id, None) - - def model_hash( - self, - model_path: Union[str, Path], - ) -> str: - """ - Given the HF repo id or path to a model on disk, returns a unique - hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs - - :param model_path: Path to model file/directory on disk. - """ - return self._local_model_hash(model_path) - - def cache_size(self) -> float: - """Return the current size of the cache, in GB.""" - return self._cache_size() / GIG - - def _has_cuda(self) -> bool: - return self.execution_device.type == "cuda" - - def _print_cuda_stats(self): - vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % self.cache_size() - - cached_models = 0 - loaded_models = 0 - locked_models = 0 - for model_info in self._cached_models.values(): - cached_models += 1 - if model_info.loaded: - loaded_models += 1 - if model_info.locked: - locked_models += 1 - - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" - f" {cached_models}/{loaded_models}/{locked_models}" - ) - - def _cache_size(self) -> int: - return sum([m.size for m in self._cached_models.values()]) - - def _make_cache_room(self, model_size): - # calculate how much memory this model will require - # multiplier = 2 if self.precision==torch.float32 else 1 - bytes_needed = model_size - maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes - current_size = self._cache_size() - - if current_size + bytes_needed > maximum_size: - self.logger.debug( - f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" - f" {(bytes_needed/GIG):.2f} GB" - ) - - self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") - - pos = 0 - models_cleared = 0 - while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): - model_key = self._cache_stack[pos] - cache_entry = self._cached_models[model_key] - - refs = sys.getrefcount(cache_entry.model) - - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None - self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," - f" refs: {refs}" - ) - - # Expected refs: - # 1 from cache_entry - # 1 from getrefcount function - # 1 from onnx runtime object - if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): - self.logger.debug( - f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" - ) - current_size -= cache_entry.size - models_cleared += 1 - if self.stats: - self.stats.cleared += 1 - del self._cache_stack[pos] - del self._cached_models[model_key] - del cache_entry - - else: - pos += 1 - - if models_cleared > 0: - # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but - # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost - # is high even if no garbage gets collected.) - # - # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: - # - If models had to be cleared, it's a signal that we are close to our memory limit. - # - If models were cleared, there's a good chance that there's a significant amount of garbage to be - # collected. - # - # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up - # immediately when their reference count hits 0. - gc.collect() - - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - - self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") - - def _offload_unlocked_models(self, size_needed: int = 0): - reserved = self.max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") - for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): - if vram_in_use <= reserved: - break - if not cache_entry.locked and cache_entry.loaded: - self._move_model_to_device(model_key, self.storage_device) - - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") - - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - - def _local_model_hash(self, model_path: Union[str, Path]) -> str: - sha = hashlib.sha256() - path = Path(model_path) - - hashpath = path / "checksum.sha256" - if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: - with open(hashpath) as f: - hash = f.read() - return hash - - self.logger.debug(f"computing hash of model {path.name}") - for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")): - with open(file, "rb") as f: - while chunk := f.read(self.sha_chunksize): - sha.update(chunk) - hash = sha.hexdigest() - with open(hashpath, "w") as f: - f.write(hash) - return hash diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py deleted file mode 100644 index 362d8d3ff55..00000000000 --- a/invokeai/backend/model_management/model_manager.py +++ /dev/null @@ -1,1121 +0,0 @@ -"""This module manages the InvokeAI `models.yaml` file, mapping -symbolic diffusers model names to the paths and repo_ids used by the -underlying `from_pretrained()` call. - -SYNOPSIS: - - mgr = ModelManager('/home/phi/invokeai/configs/models.yaml') - sd1_5 = mgr.get_model('stable-diffusion-v1-5', - model_type=ModelType.Main, - base_model=BaseModelType.StableDiffusion1, - submodel_type=SubModelType.Unet) - with sd1_5 as unet: - run_some_inference(unet) - -FETCHING MODELS: - -Models are described using four attributes: - - 1) model_name -- the symbolic name for the model - - 2) ModelType -- an enum describing the type of the model. Currently - defined types are: - ModelType.Main -- a full model capable of generating images - ModelType.Vae -- a VAE model - ModelType.Lora -- a LoRA or LyCORIS fine-tune - ModelType.TextualInversion -- a textual inversion embedding - ModelType.ControlNet -- a ControlNet model - ModelType.IPAdapter -- an IPAdapter model - - 3) BaseModelType -- an enum indicating the stable diffusion base model, one of: - BaseModelType.StableDiffusion1 - BaseModelType.StableDiffusion2 - - 4) SubModelType (optional) -- an enum that refers to one of the submodels contained - within the main model. Values are: - - SubModelType.UNet - SubModelType.TextEncoder - SubModelType.Tokenizer - SubModelType.Scheduler - SubModelType.SafetyChecker - -To fetch a model, use `manager.get_model()`. This takes the symbolic -name of the model, the ModelType, the BaseModelType and the -SubModelType. The latter is required for ModelType.Main. - -get_model() will return a ModelInfo object that can then be used in -context to retrieve the model and move it into GPU VRAM (on GPU -systems). - -A typical example is: - - sd1_5 = mgr.get_model('stable-diffusion-v1-5', - model_type=ModelType.Main, - base_model=BaseModelType.StableDiffusion1, - submodel_type=SubModelType.UNet) - with sd1_5 as unet: - run_some_inference(unet) - -The ModelInfo object provides a number of useful fields describing the -model, including: - - name -- symbolic name of the model - base_model -- base model (BaseModelType) - type -- model type (ModelType) - location -- path to the model file - precision -- torch precision of the model - hash -- unique sha256 checksum for this model - -SUBMODELS: - -When fetching a main model, you must specify the submodel. Retrieval -of full pipelines is not supported. - - vae_info = mgr.get_model('stable-diffusion-1.5', - model_type = ModelType.Main, - base_model = BaseModelType.StableDiffusion1, - submodel_type = SubModelType.Vae - ) - with vae_info as vae: - do_something(vae) - -This rule does not apply to controlnets, embeddings, loras and standalone -VAEs, which do not have submodels. - -LISTING MODELS - -The model_names() method will return a list of Tuples describing each -model it knows about: - - >> mgr.model_names() - [ - ('stable-diffusion-1.5', , ), - ('stable-diffusion-2.1', , ), - ('inpaint', , ) - ('Ink scenery', , ) - ... - ] - -The tuple is in the correct order to pass to get_model(): - - for m in mgr.model_names(): - info = get_model(*m) - -In contrast, the list_models() method returns a list of dicts, each -providing information about a model defined in models.yaml. For example: - - >>> models = mgr.list_models() - >>> json.dumps(models[0]) - {"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny", - "model_format": "diffusers", - "name": "canny", - "base_model": "sd-1", - "type": "controlnet" - } - -You can filter by model type and base model as shown here: - - - controlnets = mgr.list_models(model_type=ModelType.ControlNet, - base_model=BaseModelType.StableDiffusion1) - for c in controlnets: - name = c['name'] - format = c['model_format'] - path = c['path'] - type = c['type'] - # etc - -ADDING AND REMOVING MODELS - -At startup time, the `models` directory will be scanned for -checkpoints, diffusers pipelines, controlnets, LoRAs and TI -embeddings. New entries will be added to the model manager and defunct -ones removed. Anything that is a main model (ModelType.Main) will be -added to models.yaml. For scanning to succeed, files need to be in -their proper places. For example, a controlnet folder built on the -stable diffusion 2 base, will need to be placed in -`models/sd-2/controlnet`. - -Layout of the `models` directory: - - models - ├── sd-1 - │ ├── controlnet - │ ├── lora - │ ├── main - │ └── embedding - ├── sd-2 - │ ├── controlnet - │ ├── lora - │ ├── main - │ └── embedding - └── core - ├── face_reconstruction - │ ├── codeformer - │ └── gfpgan - ├── sd-conversion - │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs - │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs - │ └── stable-diffusion-safety-checker - └── upscaling - └─── esrgan - - - -class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are not listed -explicitly in models.yaml, but are added to the in-memory data -structure at initialization time by scanning the models directory. The -in-memory data structure can be resynchronized by calling -`manager.scan_models_directory()`. - -Files and folders placed inside the `autoimport` paths (paths -defined in `invokeai.yaml`) will also be scanned for new models at -initialization time and added to `models.yaml`. Files will not be -moved from this location but preserved in-place. These directories -are: - - configuration default description - ------------- ------- ----------- - autoimport_dir autoimport/main main models - lora_dir autoimport/lora LoRA/LyCORIS models - embedding_dir autoimport/embedding TI embeddings - controlnet_dir autoimport/controlnet ControlNet models - -In actuality, models located in any of these directories are scanned -to determine their type, so it isn't strictly necessary to organize -the different types in this way. This entry in `invokeai.yaml` will -recursively scan all subdirectories within `autoimport`, scan models -files it finds, and import them if recognized. - - Paths: - autoimport_dir: autoimport - -A model can be manually added using `add_model()` using the model's -name, base model, type and a dict of model attributes. See -`invokeai/backend/model_management/models` for the attributes required -by each model type. - -A model can be deleted using `del_model()`, providing the same -identifying information as `get_model()` - -The `heuristic_import()` method will take a set of strings -corresponding to local paths, remote URLs, and repo_ids, probe the -object to determine what type of model it is (if any), and import new -models into the manager. If passed a directory, it will recursively -scan it for models to import. The return value is a set of the models -successfully added. - -MODELS.YAML - -The general format of a models.yaml section is: - - type-of-model/name-of-model: - path: /path/to/local/file/or/directory - description: a description - format: diffusers|checkpoint - variant: normal|inpaint|depth - -The type of model is given in the stanza key, and is one of -{main, vae, lora, controlnet, textual} - -The format indicates whether the model is organized as a diffusers -folder with model subdirectories, or is contained in a single -checkpoint or safetensors file. - -The path points to a file or directory on disk. If a relative path, -the root is the InvokeAI ROOTDIR. - -""" -from __future__ import annotations - -import hashlib -import os -import textwrap -import types -from dataclasses import dataclass -from pathlib import Path -from shutil import move, rmtree -from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast - -import torch -import yaml -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig -from pydantic import BaseModel, ConfigDict, Field - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util import CUDA_DEVICE, Chdir - -from .model_cache import ModelCache, ModelLocker -from .model_search import ModelSearch -from .models import ( - MODEL_CLASSES, - BaseModelType, - DuplicateModelException, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelError, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) - -# We are only starting to number the config file with release 3. -# The config file version doesn't have to start at release version, but it will help -# reduce confusion. -CONFIG_FILE_VERSION = "3.0.0" - - -@dataclass -class ModelInfo: - context: ModelLocker - name: str - base_model: BaseModelType - type: ModelType - hash: str - location: Union[Path, str] - precision: torch.dtype - _cache: Optional[ModelCache] = None - - def __enter__(self): - return self.context.__enter__() - - def __exit__(self, *args, **kwargs): - self.context.__exit__(*args, **kwargs) - - -class AddModelResult(BaseModel): - name: str = Field(description="The name of the model after installation") - model_type: ModelType = Field(description="The type of model") - base_model: BaseModelType = Field(description="The base model") - config: ModelConfigBase = Field(description="The configuration of the model") - - model_config = ConfigDict(protected_namespaces=()) - - -MAX_CACHE_SIZE = 6.0 # GB - - -class ConfigMeta(BaseModel): - version: str - - -class ModelManager(object): - """ - High-level interface to model management. - """ - - logger: types.ModuleType = logger - - def __init__( - self, - config: Union[Path, DictConfig, str], - device_type: torch.device = CUDA_DEVICE, - precision: torch.dtype = torch.float16, - max_cache_size=MAX_CACHE_SIZE, - sequential_offload=False, - logger: types.ModuleType = logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - self.config_path = None - if isinstance(config, (str, Path)): - self.config_path = Path(config) - if not self.config_path.exists(): - logger.warning(f"The file {self.config_path} was not found. Initializing a new file") - self.initialize_model_config(self.config_path) - config = OmegaConf.load(self.config_path) - - elif not isinstance(config, DictConfig): - raise ValueError("config argument must be an OmegaConf object, a Path or a string") - - self.config_meta = ConfigMeta(**config.pop("__metadata__")) - # TODO: metadata not found - # TODO: version check - - self.app_config = InvokeAIAppConfig.get_config() - self.logger = logger - self.cache = ModelCache( - max_cache_size=max_cache_size, - max_vram_cache_size=self.app_config.vram_cache_size, - lazy_offloading=self.app_config.lazy_offload, - execution_device=device_type, - precision=precision, - sequential_offload=sequential_offload, - logger=logger, - log_memory_usage=self.app_config.log_memory_usage, - ) - - self._read_models(config) - - def _read_models(self, config: Optional[DictConfig] = None): - if not config: - if self.config_path: - config = OmegaConf.load(self.config_path) - else: - return - - self.models = {} - for model_key, model_config in config.items(): - if model_key.startswith("_"): - continue - model_name, base_model, model_type = self.parse_key(model_key) - model_class = self._get_implementation(base_model, model_type) - # alias for config file - model_config["model_format"] = model_config.pop("format") - self.models[model_key] = model_class.create_config(**model_config) - - # check config version number and update on disk/RAM if necessary - self.cache_keys = {} - - # add controlnet, lora and textual_inversion models from disk - self.scan_models_directory() - - def sync_to_config(self): - """ - Call this when `models.yaml` has been changed externally. - This will reinitialize internal data structures - """ - # Reread models directory; note that this will reinitialize the cache, - # causing otherwise unreferenced models to be removed from memory - self._read_models() - - def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: - """ - Given a model name, returns True if it is a valid identifier. - - :param model_name: symbolic name of the model in models.yaml - :param model_type: ModelType enum indicating the type of model to return - :param base_model: BaseModelType enum indicating the base model used by this model - :param rescan: if True, scan_models_directory - """ - model_key = self.create_key(model_name, base_model, model_type) - exists = model_key in self.models - - # if model not found try to find it (maybe file just pasted) - if rescan and not exists: - self.scan_models_directory(base_model=base_model, model_type=model_type) - exists = self.model_exists(model_name, base_model, model_type, rescan=False) - - return exists - - @classmethod - def create_key( - cls, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> str: - # In 3.11, the behavior of (str,enum) when interpolated into a - # string has changed. The next two lines are defensive. - base_model = BaseModelType(base_model) - model_type = ModelType(model_type) - return f"{base_model.value}/{model_type.value}/{model_name}" - - @classmethod - def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]: - base_model_str, model_type_str, model_name = model_key.split("/", 2) - try: - model_type = ModelType(model_type_str) - except Exception: - raise Exception(f"Unknown model type: {model_type_str}") - - try: - base_model = BaseModelType(base_model_str) - except Exception: - raise Exception(f"Unknown base model: {base_model_str}") - - return (model_name, base_model, model_type) - - def _get_model_cache_path(self, model_path): - return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest()) - - @classmethod - def initialize_model_config(cls, config_path: Path): - """Create empty config file""" - with open(config_path, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ) -> ModelInfo: - """Given a model named identified in models.yaml, return - an ModelInfo object describing it. - :param model_name: symbolic name of the model in models.yaml - :param model_type: ModelType enum indicating the type of model to return - :param base_model: BaseModelType enum indicating the base model used by this model - :param submodel_type: an ModelType enum indicating the portion of - the model to retrieve (e.g. ModelType.Vae) - """ - model_key = self.create_key(model_name, base_model, model_type) - - if not self.model_exists(model_name, base_model, model_type, rescan=True): - raise ModelNotFoundException(f"Model not found - {model_key}") - - model_config = self._get_model_config(base_model, model_name, model_type) - - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - - if is_submodel_override: - model_type = submodel_type - submodel_type = None - - model_class = self._get_implementation(base_model, model_type) - - if not model_path.exists(): - if model_class.save_to_config: - self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found at {model_path}') - - else: - self.models.pop(model_key, None) - raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') - - # TODO: path - # TODO: is it accurate to use path as id - dst_convert_path = self._get_model_cache_path(model_path) - - model_path = model_class.convert_if_required( - base_model=base_model, - model_path=str(model_path), # TODO: refactor str/Path types logic - output_path=dst_convert_path, - config=model_config, - ) - - model_context = self.cache.get_model( - model_path=model_path, - model_class=model_class, - base_model=base_model, - model_type=model_type, - submodel=submodel_type, - ) - - if model_key not in self.cache_keys: - self.cache_keys[model_key] = set() - self.cache_keys[model_key].add(model_context.key) - - model_hash = "" # TODO: - - return ModelInfo( - context=model_context, - name=model_name, - base_model=base_model, - type=submodel_type or model_type, - hash=model_hash, - location=model_path, # TODO: - precision=self.cache.precision, - _cache=self.cache, - ) - - def _get_model_path( - self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None - ) -> (Path, bool): - """Extract a model's filesystem path from its config. - - :return: The fully qualified Path of the module (or submodule). - """ - model_path = model_config.path - is_submodel_override = False - - # Does the config explicitly override the submodel? - if submodel_type is not None and hasattr(model_config, submodel_type): - submodel_path = getattr(model_config, submodel_type) - if submodel_path is not None and len(submodel_path) > 0: - model_path = getattr(model_config, submodel_type) - is_submodel_override = True - - model_path = self.resolve_model_path(model_path) - return model_path, is_submodel_override - - def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: - """Get a model's config object.""" - model_key = self.create_key(model_name, base_model, model_type) - try: - model_config = self.models[model_key] - except KeyError: - raise ModelNotFoundException(f"Model not found - {model_key}") - return model_config - - def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: - """Get the concrete implementation class for a specific model type.""" - model_class = MODEL_CLASSES[base_model][model_type] - return model_class - - def _instantiate( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ) -> ModelBase: - """Make a new instance of this model, without loading it.""" - model_config = self._get_model_config(base_model, model_name, model_type) - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - # FIXME: do non-overriden submodels get the right class? - constructor = self._get_implementation(base_model, model_type) - instance = constructor(model_path, base_model, model_type) - return instance - - def model_info( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> Union[dict, None]: - """ - Given a model name returns the OmegaConf (dict-like) object describing it. - """ - model_key = self.create_key(model_name, base_model, model_type) - if model_key in self.models: - return self.models[model_key].model_dump(exclude_defaults=True) - else: - return None # TODO: None or empty dict on not found - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Return a list of (str, BaseModelType, ModelType) corresponding to all models - known to the configuration. - """ - return [(self.parse_key(x)) for x in self.models.keys()] - - def list_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> Union[dict, None]: - """ - Returns a dict describing one installed model, using - the combined format of the list_models() method. - """ - models = self.list_models(base_model, model_type, model_name) - if len(models) >= 1: - return models[0] - else: - return None - - def list_models( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - model_name: Optional[str] = None, - ) -> list[dict]: - """ - Return a list of models. - """ - - model_keys = ( - [self.create_key(model_name, base_model, model_type)] - if model_name and base_model and model_type - else sorted(self.models, key=str.casefold) - ) - models = [] - for model_key in model_keys: - model_config = self.models.get(model_key) - if not model_config: - self.logger.error(f"Unknown model {model_name}") - raise ModelNotFoundException(f"Unknown model {model_name}") - - cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) - if base_model is not None and cur_base_model != base_model: - continue - if model_type is not None and cur_model_type != model_type: - continue - - model_dict = dict( - **model_config.model_dump(exclude_defaults=True), - # OpenAPIModelInfoBase - model_name=cur_model_name, - base_model=cur_base_model, - model_type=cur_model_type, - ) - - # expose paths as absolute to help web UI - if path := model_dict.get("path"): - model_dict["path"] = str(self.resolve_model_path(path)) - models.append(model_dict) - - return models - - def print_models(self) -> None: - """ - Print a table of models and their descriptions. This needs to be redone - """ - # TODO: redo - for model_dict in self.list_models(): - for _model_name, model_info in model_dict.items(): - line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' - print(line) - - # Tested - LS - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model. - """ - model_key = self.create_key(model_name, base_model, model_type) - model_cfg = self.models.pop(model_key, None) - - if model_cfg is None: - raise ModelNotFoundException(f"Unknown model {model_key}") - - # note: it not garantie to release memory(model can has other references) - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - # if model inside invoke models folder - delete files - model_path = self.resolve_model_path(model_cfg.path) - cache_path = self._get_model_cache_path(model_path) - if cache_path.exists(): - rmtree(str(cache_path)) - - if model_path.is_relative_to(self.app_config.models_path): - if model_path.is_dir(): - rmtree(str(model_path)) - else: - model_path.unlink() - self.commit() - - # LS: tested - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory and the - method will return True. Will fail with an assertion error if provided - attributes are incorrect or the model name is missing. - - The returned dict has the same format as the dict returned by - model_info(). - """ - # relativize paths as they go in - this makes it easier to move the models directory around - if path := model_attributes.get("path"): - model_attributes["path"] = str(self.relative_model_path(Path(path))) - - model_class = self._get_implementation(base_model, model_type) - model_config = model_class.create_config(**model_attributes) - model_key = self.create_key(model_name, base_model, model_type) - - if model_key in self.models and not clobber: - raise Exception(f'Attempt to overwrite existing model definition "{model_key}"') - - old_model = self.models.pop(model_key, None) - if old_model is not None: - # TODO: if path changed and old_model.path inside models folder should we delete this too? - - # remove conversion cache as config changed - old_model_path = self.resolve_model_path(old_model.path) - old_model_cache = self._get_model_cache_path(old_model_path) - if old_model_cache.exists(): - if old_model_cache.is_dir(): - rmtree(str(old_model_cache)) - else: - old_model_cache.unlink() - - # remove in-memory cache - # note: it not guaranteed to release memory(model can has other references) - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - self.models[model_key] = model_config - self.commit() - - return AddModelResult( - name=model_name, - model_type=model_type, - base_model=base_model, - config=model_config, - ) - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ) -> None: - """ - Rename or rebase a model. - """ - if new_name is None and new_base is None: - self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") - return - - model_key = self.create_key(model_name, base_model, model_type) - model_cfg = self.models.get(model_key, None) - if not model_cfg: - raise ModelNotFoundException(f"Unknown model: {model_key}") - - old_path = self.resolve_model_path(model_cfg.path) - new_name = new_name or model_name - new_base = new_base or base_model - new_key = self.create_key(new_name, new_base, model_type) - if new_key in self.models: - raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"') - - # if this is a model file/directory that we manage ourselves, we need to move it - if old_path.is_relative_to(self.app_config.models_path): - # keep the suffix! - if old_path.is_file(): - new_name = Path(new_name).with_suffix(old_path.suffix).as_posix() - new_path = self.resolve_model_path( - Path( - BaseModelType(new_base).value, - ModelType(model_type).value, - new_name, - ) - ) - move(old_path, new_path) - model_cfg.path = str(new_path.relative_to(self.app_config.models_path)) - - # clean up caches - old_model_cache = self._get_model_cache_path(old_path) - if old_model_cache.exists(): - if old_model_cache.is_dir(): - rmtree(str(old_model_cache)) - else: - old_model_cache.unlink() - - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - self.models.pop(model_key, None) # delete - self.models[new_key] = model_cfg - self.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is a checkpoint. - """ - info = self.model_info(model_name, base_model, model_type) - - if info is None: - raise FileNotFoundError(f"model not found: {model_name}") - - if info["model_format"] != "checkpoint": - raise ValueError(f"not a checkpoint format model: {model_name}") - - # We are taking advantage of a side effect of get_model() that converts check points - # into cached diffusers directories stored at `location`. It doesn't matter - # what submodeltype we request here, so we get the smallest. - submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {} - model = self.get_model( - model_name, - base_model, - model_type, - **submodel, - ) - checkpoint_path = self.resolve_model_path(info["path"]) - old_diffusers_path = self.resolve_model_path(model.location) - new_diffusers_path = ( - dest_directory or self.app_config.models_path / base_model.value / model_type.value - ) / model_name - if new_diffusers_path.exists(): - raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") - - try: - move(old_diffusers_path, new_diffusers_path) - info["model_format"] = "diffusers" - info["path"] = ( - str(new_diffusers_path) - if dest_directory - else str(new_diffusers_path.relative_to(self.app_config.models_path)) - ) - info.pop("config") - - result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True) - except Exception: - # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! - rmtree(new_diffusers_path) - raise - - if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path): - checkpoint_path.unlink() - - return result - - def resolve_model_path(self, path: Union[Path, str]) -> Path: - """return relative paths based on configured models_path""" - return self.app_config.models_path / path - - def relative_model_path(self, model_path: Path) -> Path: - if model_path.is_relative_to(self.app_config.models_path): - model_path = model_path.relative_to(self.app_config.models_path) - return model_path - - def search_models(self, search_folder): - self.logger.info(f"Finding Models In: {search_folder}") - models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") - models_folder_safetensors = Path(search_folder).glob("**/*.safetensors") - - ckpt_files = [x for x in models_folder_ckpt if x.is_file()] - safetensor_files = [x for x in models_folder_safetensors if x.is_file()] - - files = ckpt_files + safetensor_files - - found_models = [] - for file in files: - location = str(file.resolve()).replace("\\", "/") - if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location: - found_models.append({"name": file.stem, "location": location}) - - return search_folder, found_models - - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - """ - data_to_save = {} - data_to_save["__metadata__"] = self.config_meta.model_dump() - - for model_key, model_config in self.models.items(): - model_name, base_model, model_type = self.parse_key(model_key) - model_class = self._get_implementation(base_model, model_type) - if model_class.save_to_config: - # TODO: or exclude_unset better fits here? - data_to_save[model_key] = cast(BaseModel, model_config).model_dump( - exclude_defaults=True, exclude={"error"}, mode="json" - ) - # alias for config file - data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format") - - yaml_str = OmegaConf.to_yaml(data_to_save) - config_file_path = conf_file or self.config_path - assert config_file_path is not None, "no config file path to write to" - config_file_path = self.app_config.root_path / config_file_path - tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") - try: - with open(tmpfile, "w", encoding="utf-8") as outfile: - outfile.write(self.preamble()) - outfile.write(yaml_str) - os.replace(tmpfile, config_file_path) - except OSError as err: - self.logger.warning(f"Could not modify the config file at {config_file_path}") - self.logger.warning(err) - - def preamble(self) -> str: - """ - Returns the preamble for the config file. - """ - return textwrap.dedent( - """ - # This file describes the alternative machine learning models - # available to InvokeAI script. - # - # To add a new model, follow the examples below. Each - # model requires a model config file, a weights file, - # and the width and height of the images it - # was trained on. - """ - ) - - def scan_models_directory( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - ): - loaded_files = set() - new_models_found = False - - self.logger.info(f"Scanning {self.app_config.models_path} for new models") - with Chdir(self.app_config.models_path): - for model_key, model_config in list(self.models.items()): - model_name, cur_base_model, cur_model_type = self.parse_key(model_key) - - # Patch for relative path bug in older models.yaml - paths should not - # be starting with a hard-coded 'models'. This will also fix up - # models.yaml when committed. - if model_config.path.startswith("models"): - model_config.path = str(Path(*Path(model_config.path).parts[1:])) - - model_path = self.resolve_model_path(model_config.path).absolute() - if not model_path.exists(): - model_class = self._get_implementation(cur_base_model, cur_model_type) - if model_class.save_to_config: - model_config.error = ModelError.NotFound - self.models.pop(model_key, None) - else: - self.models.pop(model_key, None) - else: - loaded_files.add(model_path) - - for cur_base_model in BaseModelType: - if base_model is not None and cur_base_model != base_model: - continue - - for cur_model_type in ModelType: - if model_type is not None and cur_model_type != model_type: - continue - model_class = self._get_implementation(cur_base_model, cur_model_type) - models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) - - if not models_dir.exists(): - continue # TODO: or create all folders? - - for model_path in models_dir.iterdir(): - if model_path not in loaded_files: # TODO: check - if model_path.name.startswith("."): - continue - model_name = model_path.name if model_path.is_dir() else model_path.stem - model_key = self.create_key(model_name, cur_base_model, cur_model_type) - - try: - if model_key in self.models: - raise DuplicateModelException(f"Model with key {model_key} added twice") - - model_path = self.relative_model_path(model_path) - model_config: ModelConfigBase = model_class.probe_config( - str(model_path), model_base=cur_base_model - ) - self.models[model_key] = model_config - new_models_found = True - except DuplicateModelException as e: - self.logger.warning(e) - except InvalidModelException as e: - self.logger.warning(f"Not a valid model: {model_path}. {e}") - except NotImplementedError as e: - self.logger.warning(e) - except Exception as e: - self.logger.warning(f"Error loading model {model_path}. {e}") - - imported_models = self.scan_autoimport_directory() - if (new_models_found or imported_models) and self.config_path: - self.commit() - - def scan_autoimport_directory(self) -> Dict[str, AddModelResult]: - """ - Scan the autoimport directory (if defined) and import new models, delete defunct models. - """ - # avoid circular import - from invokeai.backend.install.model_install_backend import ModelInstall - from invokeai.frontend.install.model_install import ask_user_for_prediction_type - - class ScanAndImport(ModelSearch): - def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): - super().__init__(directories, logger) - self.installer = installer - self.ignore = ignore - - def on_search_started(self): - self.new_models_found = {} - - def on_model_found(self, model: Path): - if model not in self.ignore: - self.new_models_found.update(self.installer.heuristic_import(model)) - - def on_search_completed(self): - self.logger.info( - f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models" - ) - - def models_found(self): - return self.new_models_found - - config = self.app_config - - # LS: hacky - # Patch in the SD VAE from core so that it is available for use by the UI - try: - self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))}) - except Exception: - pass - - installer = ModelInstall( - config=self.app_config, - model_manager=self, - prediction_type_helper=ask_user_for_prediction_type, - ) - known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()} - directories = { - config.root_path / x - for x in [ - config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir, - ] - if x - } - scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) - scanner.search() - - return scanner.models_found() - - def heuristic_import( - self, - items_to_import: Set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> Dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - - May return the following exceptions: - - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL - - ValueError - a corresponding model already exists - """ - # avoid circular import here - from invokeai.backend.install.model_install_backend import ModelInstall - - successfully_installed = {} - - installer = ModelInstall( - config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self - ) - for thing in items_to_import: - installed = installer.heuristic_import(thing) - successfully_installed.update(installed) - self.commit() - return successfully_installed diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py deleted file mode 100644 index a9f0a23618e..00000000000 --- a/invokeai/backend/model_management/model_merge.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -invokeai.backend.model_management.model_merge exports: -merge_diffusion_models() -- combine multiple models by location and return a pipeline object -merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml - -Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team -""" - -import warnings -from enum import Enum -from pathlib import Path -from typing import List, Optional, Union - -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging - -import invokeai.backend.util.logging as logger - -from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType - - -class MergeInterpolationMethod(str, Enum): - WeightedSum = "weighted_sum" - Sigmoid = "sigmoid" - InvSigmoid = "inv_sigmoid" - AddDifference = "add_difference" - - -class ModelMerger(object): - def __init__(self, manager: ModelManager): - self.manager = manager - - def merge_diffusion_models( - self, - model_paths: List[Path], - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - **kwargs, - ) -> DiffusionPipeline: - """ - :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipe = DiffusionPipeline.from_pretrained( - model_paths[0], - custom_pipeline="checkpoint_merger", - ) - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_paths, - alpha=alpha, - interp=interp.value if interp else None, # diffusers API treats None as "weighted sum" - force=force, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - def merge_diffusion_models_and_save( - self, - model_names: List[str], - base_model: Union[BaseModelType, str], - merged_model_name: str, - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = None, - **kwargs, - ) -> AddModelResult: - """ - :param models: up to three models, designated by their InvokeAI models.yaml model name - :param base_model: base model (must be the same for all merged models!) - :param merged_model_name: name for new model - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - model_paths = [] - config = self.manager.app_config - base_model = BaseModelType(base_model) - vae = None - - for mod in model_names: - info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" - assert ( - info["model_format"] == "diffusers" - ), f"{mod} is not a diffusers model. It must be optimized before merging" - assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" - assert ( - len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference - ), "When merging three models, only the 'add_difference' merge method is supported" - # pick up the first model's vae - if mod == model_names[0]: - vae = info.get("vae") - model_paths.extend([(config.root_path / info["path"]).as_posix()]) - - merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) - logger.debug(f"interp = {interp}, merge_method={merge_method}") - merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs) - dump_path = ( - Path(merge_dest_directory) - if merge_dest_directory - else config.models_path / base_model.value / ModelType.Main.value - ) - dump_path.mkdir(parents=True, exist_ok=True) - dump_path = (dump_path / merged_model_name).as_posix() - - merged_pipe.save_pretrained(dump_path, safe_serialization=True) - attributes = { - "path": dump_path, - "description": f"Merge of models {', '.join(model_names)}", - "model_format": "diffusers", - "variant": ModelVariantType.Normal.value, - "vae": vae, - } - return self.manager.add_model( - merged_model_name, - base_model=base_model, - model_type=ModelType.Main, - model_attributes=attributes, - clobber=True, - ) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py deleted file mode 100644 index 74b1b72d317..00000000000 --- a/invokeai/backend/model_management/model_probe.py +++ /dev/null @@ -1,664 +0,0 @@ -import json -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Dict, Literal, Optional, Union - -import safetensors.torch -import torch -from diffusers import ConfigMixin, ModelMixin -from picklescan.scanner import scan_file_path - -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat - -from .models import ( - BaseModelType, - InvalidModelException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, -) -from .models.base import read_checkpoint_meta -from .util import lora_token_vector_length - - -@dataclass -class ModelProbeInfo(object): - model_type: ModelType - base_type: BaseModelType - variant_type: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"] - image_size: int - name: Optional[str] = None - description: Optional[str] = None - - -class ProbeBase(object): - """forward declaration""" - - pass - - -class ModelProbe(object): - PROBES = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.Vae, - "AutoencoderTiny": ModelType.Vae, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - "T2IAdapter": ModelType.T2IAdapter, - } - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase - ): - cls.PROBES[format][model_type] = probe_class - - @classmethod - def heuristic_probe( - cls, - model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, - ) -> ModelProbeInfo: - if isinstance(model, Path): - return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper) - elif isinstance(model, (dict, ModelMixin, ConfigMixin)): - return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) - else: - raise InvalidModelException("model parameter {model} is neither a Path, nor a model") - - @classmethod - def probe( - cls, - model_path: Path, - model: Optional[Union[Dict, ModelMixin]] = None, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> ModelProbeInfo: - """ - Probe the model at model_path and return sufficient information about it - to place it somewhere in the models directory hierarchy. If the model is - already loaded into memory, you may provide it as model in order to avoid - opening it a second time. The prediction_type_helper callable is a function that receives - the path to the model and returns the SchedulerPredictionType. - """ - if model_path: - format_type = "diffusers" if model_path.is_dir() else "checkpoint" - else: - format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint" - model_info = None - try: - model_type = ( - cls.get_model_type_from_folder(model_path, model) - if format_type == "diffusers" - else cls.get_model_type_from_checkpoint(model_path, model) - ) - format_type = "onnx" if model_type == ModelType.ONNX else format_type - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - return None - probe = probe_class(model_path, model, prediction_type_helper) - base_type = probe.get_base_type() - variant_type = probe.get_variant_type() - prediction_type = probe.get_scheduler_prediction_type() - name = cls.get_model_name(model_path) - description = f"{base_type.value} {model_type.value} model {name}" - format = probe.get_format() - model_info = ModelProbeInfo( - model_type=model_type, - base_type=base_type, - variant_type=variant_type, - prediction_type=prediction_type, - name=name, - description=description, - upcast_attention=( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ), - format=format, - image_size=( - 1024 - if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) - else ( - 768 - if ( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ) - else 512 - ) - ), - ) - except Exception: - raise - - return model_info - - @classmethod - def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - return model_path.stem - else: - return model_path.name - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): - return None - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - for key in ckpt.keys(): - if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): - return ModelType.Main - elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): - return ModelType.Vae - elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): - return ModelType.Lora - elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): - return ModelType.Lora - elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): - return ModelType.ControlNet - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - else: - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - raise InvalidModelException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType: - """ - Get the model type of a hugging-face style folder. - """ - class_name = None - error_hint = None - if model: - class_name = model.__class__.__name__ - else: - for suffix in ["bin", "safetensors"]: - if (folder_path / f"learned_embeds.{suffix}").exists(): - return ModelType.TextualInversion - if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.Lora - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - i = folder_path / "model_index.json" - c = folder_path / "config.json" - config_path = i if i.exists() else c if c.exists() else None - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." - - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): - cls._scan_model(model_path, model_path) - return torch.load(model_path, map_location="cpu") - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise Exception("The model {model_name} is potentially infected by malware. Aborting import.") - - -# ##################################################3 -# Checkpoint probing -# ##################################################3 -class ProbeBase(object): - def get_base_type(self) -> BaseModelType: - pass - - def get_variant_type(self) -> ModelVariantType: - pass - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - pass - - def get_format(self) -> str: - pass - - -class CheckpointProbeBase(ProbeBase): - def __init__( - self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None - ) -> BaseModelType: - self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) - self.checkpoint_path = checkpoint_path - self.helper = helper - - def get_base_type(self) -> BaseModelType: - pass - - def get_format(self) -> str: - return "checkpoint" - - def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint) - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: - """Return model prediction type.""" - # if there is a .yaml associated with this checkpoint, then we do not need - # to probe for the prediction type as it will be ignored. - if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists(): - return None - - type = self.get_base_type() - if type == BaseModelType.StableDiffusion2: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - if self.helper and self.checkpoint_path: - if helper_guess := self.helper(self.checkpoint_path): - return helper_guess - return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts - - elif type == BaseModelType.StableDiffusion1: - if self.helper and self.checkpoint_path: - if helper_guess := self.helper(self.checkpoint_path): - return helper_guess - return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts - else: - return None - - -class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1 - - -class LoRACheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return "lycoris" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - token_vector_length = lora_token_vector_length(checkpoint) - - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}") - - -class TextualInversionCheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if "string_to_token" in checkpoint: - token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] - elif "emb_params" in checkpoint: - token_dim = checkpoint["emb_params"].shape[-1] - elif "clip_g" in checkpoint: - token_dim = checkpoint["clip_g"].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[-1] - if token_dim == 768: - return BaseModelType.StableDiffusion1 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2 - elif token_dim == 1280: - return BaseModelType.StableDiffusionXL - else: - return None - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - if checkpoint[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - elif checkpoint[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - elif self.checkpoint_path and self.helper: - return self.helper(self.checkpoint_path) - raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class T2IAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used - self.model = model - self.folder_path = folder_path - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> str: - return "diffusers" - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self.model: - unet_conf = self.model.unet.config - else: - with open(self.folder_path / "unet" / "config.json", "r") as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown base model for {self.folder_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - if self.model: - scheduler_conf = self.model.scheduler.config - else: - with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf["prediction_type"] == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf["prediction_type"] == "epsilon": - return SchedulerPredictionType.Epsilon - else: - return None - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - if self.model: - conf = self.model.unet.config - else: - config_file = self.folder_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class VaeFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self._config_looks_like_sdxl(): - return BaseModelType.StableDiffusionXL - elif self._name_looks_like_sdxl(): - # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. - return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.folder_path.name - if name == "vae": - name = self.folder_path.parent.name - return name - - -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - path = self.folder_path / "learned_embeds.bin" - if not path.exists(): - return None - checkpoint = ModelProbe._scan_and_load_checkpoint(path) - return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type() - - -class ONNXFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return "onnx" - - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - base_model = ( - BaseModelType.StableDiffusion1 - if dimension == 768 - else ( - BaseModelType.StableDiffusion2 - if dimension == 1024 - else BaseModelType.StableDiffusionXL - if dimension == 2048 - else None - ) - ) - if not base_model: - raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") - return base_model - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file, None).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return IPAdapterModelFormat.InvokeAI.value - - def get_base_type(self) -> BaseModelType: - model_file = self.folder_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class T2IAdapterFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - - adapter_type = config.get("adapter_type", None) - if adapter_type == "full_adapter_xl": - return BaseModelType.StableDiffusionXL - elif adapter_type == "full_adapter" or "light_adapter": - # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. - return BaseModelType.StableDiffusion1 - else: - raise InvalidModelException( - f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})." - ) - - -############## register probe classes ###### -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py deleted file mode 100644 index e125c3ced7f..00000000000 --- a/invokeai/backend/model_management/model_search.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2023, Lincoln D. Stein and the InvokeAI Team -""" -Abstract base class for recursive directory search for models. -""" - -import os -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Set, types - -import invokeai.backend.util.logging as logger - - -class ModelSearch(ABC): - def __init__(self, directories: List[Path], logger: types.ModuleType = logger): - """ - Initialize a recursive model directory search. - :param directories: List of directory Paths to recurse through - :param logger: Logger to use - """ - self.directories = directories - self.logger = logger - self._items_scanned = 0 - self._models_found = 0 - self._scanned_dirs = set() - self._scanned_paths = set() - self._pruned_paths = set() - - @abstractmethod - def on_search_started(self): - """ - Called before the scan starts. - """ - pass - - @abstractmethod - def on_model_found(self, model: Path): - """ - Process a found model. Raise an exception if something goes wrong. - :param model: Model to process - could be a directory or checkpoint. - """ - pass - - @abstractmethod - def on_search_completed(self): - """ - Perform some activity when the scan is completed. May use instance - variables, items_scanned and models_found - """ - pass - - def search(self): - self.on_search_started() - for dir in self.directories: - self.walk_directory(dir) - self.on_search_completed() - - def walk_directory(self, path: Path): - for root, dirs, files in os.walk(path, followlinks=True): - if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) - if any(Path(root).is_relative_to(x) for x in self._pruned_paths): - continue - - self._items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in self._scanned_paths or path.parent in self._scanned_dirs: - self._scanned_dirs.add(path) - continue - if any( - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - } - ): - try: - self.on_model_found(path) - self._models_found += 1 - self._scanned_dirs.add(path) - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - for f in files: - path = Path(root) / f - if path.parent in self._scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.on_model_found(path) - self._models_found += 1 - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - -class FindModels(ModelSearch): - def on_search_started(self): - self.models_found: Set[Path] = set() - - def on_model_found(self, model: Path): - self.models_found.add(model) - - def on_search_completed(self): - pass - - def list_models(self) -> List[Path]: - self.search() - return list(self.models_found) diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py deleted file mode 100644 index 5f9b13b96f1..00000000000 --- a/invokeai/backend/model_management/models/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -import inspect -from enum import Enum -from typing import Literal, get_origin - -from pydantic import BaseModel, ConfigDict, create_model - -from .base import ( # noqa: F401 - BaseModelType, - DuplicateModelException, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelError, - ModelNotFoundException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, - SubModelType, -) -from .clip_vision import CLIPVisionModel -from .controlnet import ControlNetModel # TODO: -from .ip_adapter import IPAdapterModel -from .lora import LoRAModel -from .sdxl import StableDiffusionXLModel -from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model -from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model -from .t2i_adapter import T2IAdapterModel -from .textual_inversion import TextualInversionModel -from .vae import VaeModel - -MODEL_CLASSES = { - BaseModelType.StableDiffusion1: { - ModelType.ONNX: ONNXStableDiffusion1Model, - ModelType.Main: StableDiffusion1Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusion2: { - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.Main: StableDiffusion2Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusionXL: { - ModelType.Main: StableDiffusionXLModel, - ModelType.Vae: VaeModel, - # will not work until support written - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelType.Main: StableDiffusionXLModel, - ModelType.Vae: VaeModel, - # will not work until support written - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.Any: { - ModelType.CLIPVision: CLIPVisionModel, - # The following model types are not expected to be used with BaseModelType.Any. - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.Main: StableDiffusion2Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - # BaseModelType.Kandinsky2_1: { - # ModelType.Main: Kandinsky2_1Model, - # ModelType.MoVQ: MoVQModel, - # ModelType.Lora: LoRAModel, - # ModelType.ControlNet: ControlNetModel, - # ModelType.TextualInversion: TextualInversionModel, - # }, -} - -MODEL_CONFIGS = [] -OPENAPI_MODEL_CONFIGS = [] - - -class OpenAPIModelInfoBase(BaseModel): - model_name: str - base_model: BaseModelType - model_type: ModelType - - model_config = ConfigDict(protected_namespaces=()) - - -for _base_model, models in MODEL_CLASSES.items(): - for model_type, model_class in models.items(): - model_configs = set(model_class._get_configs().values()) - model_configs.discard(None) - MODEL_CONFIGS.extend(model_configs) - - # LS: sort to get the checkpoint configs first, which makes - # for a better template in the Swagger docs - for cfg in sorted(model_configs, key=lambda x: str(x)): - model_name, cfg_name = cfg.__qualname__.split(".")[-2:] - openapi_cfg_name = model_name + cfg_name - if openapi_cfg_name in vars(): - continue - - api_wrapper = create_model( - openapi_cfg_name, - __base__=(cfg, OpenAPIModelInfoBase), - model_type=(Literal[model_type], model_type), # type: ignore - ) - vars()[openapi_cfg_name] = api_wrapper - OPENAPI_MODEL_CONFIGS.append(api_wrapper) - - -def get_model_config_enums(): - enums = [] - - for model_config in MODEL_CONFIGS: - if hasattr(inspect, "get_annotations"): - fields = inspect.get_annotations(model_config) - else: - fields = model_config.__annotations__ - try: - field = fields["model_format"] - except Exception: - raise Exception("format field not found") - - # model_format: None - # model_format: SomeModelFormat - # model_format: Literal[SomeModelFormat.Diffusers] - # model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint] - - if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): - enums.append(field) - - elif get_origin(field) is Literal and all( - isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ - ): - enums.append(type(field.__args__[0])) - - elif field is None: - pass - - else: - raise Exception(f"Unsupported format definition in {model_configs.__qualname__}") - - return enums diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py deleted file mode 100644 index 7807cb9a542..00000000000 --- a/invokeai/backend/model_management/models/base.py +++ /dev/null @@ -1,681 +0,0 @@ -import inspect -import json -import os -import sys -import typing -import warnings -from abc import ABCMeta, abstractmethod -from contextlib import suppress -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union - -import numpy as np -import onnx -import safetensors.torch -import torch -from diffusers import ConfigMixin, DiffusionPipeline -from diffusers import logging as diffusers_logging -from onnx import numpy_helper -from onnxruntime import InferenceSession, SessionOptions, get_available_providers -from picklescan.scanner import scan_file_path -from pydantic import BaseModel, ConfigDict, Field -from transformers import logging as transformers_logging - - -class DuplicateModelException(Exception): - pass - - -class InvalidModelException(Exception): - pass - - -class ModelNotFoundException(Exception): - pass - - -class BaseModelType(str, Enum): - Any = "any" # For models that are not associated with any particular base model. - StableDiffusion1 = "sd-1" - StableDiffusion2 = "sd-2" - StableDiffusionXL = "sdxl" - StableDiffusionXLRefiner = "sdxl-refiner" - # Kandinsky2_1 = "kandinsky-2.1" - - -class ModelType(str, Enum): - ONNX = "onnx" - Main = "main" - Vae = "vae" - Lora = "lora" - ControlNet = "controlnet" # used by model_probe - TextualInversion = "embedding" - IPAdapter = "ip_adapter" - CLIPVision = "clip_vision" - T2IAdapter = "t2i_adapter" - - -class SubModelType(str, Enum): - UNet = "unet" - TextEncoder = "text_encoder" - TextEncoder2 = "text_encoder_2" - Tokenizer = "tokenizer" - Tokenizer2 = "tokenizer_2" - Vae = "vae" - VaeDecoder = "vae_decoder" - VaeEncoder = "vae_encoder" - Scheduler = "scheduler" - SafetyChecker = "safety_checker" - # MoVQ = "movq" - - -class ModelVariantType(str, Enum): - Normal = "normal" - Inpaint = "inpaint" - Depth = "depth" - - -class SchedulerPredictionType(str, Enum): - Epsilon = "epsilon" - VPrediction = "v_prediction" - Sample = "sample" - - -class ModelError(str, Enum): - NotFound = "not_found" - - -def model_config_json_schema_extra(schema: dict[str, Any]) -> None: - if "required" not in schema: - schema["required"] = [] - schema["required"].append("model_type") - - -class ModelConfigBase(BaseModel): - path: str # or Path - description: Optional[str] = Field(None) - model_format: Optional[str] = Field(None) - error: Optional[ModelError] = Field(None) - - model_config = ConfigDict( - use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra - ) - - -class EmptyConfigLoader(ConfigMixin): - @classmethod - def load_config(cls, *args, **kwargs): - cls.config_name = kwargs.pop("config_name") - return super().load_config(*args, **kwargs) - - -T_co = TypeVar("T_co", covariant=True) - - -class classproperty(Generic[T_co]): - def __init__(self, fget: Callable[[Any], T_co]) -> None: - self.fget = fget - - def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co: - return self.fget(owner) - - def __set__(self, instance: Optional[Any], value: Any) -> None: - raise AttributeError("cannot set attribute") - - -class ModelBase(metaclass=ABCMeta): - # model_path: str - # base_model: BaseModelType - # model_type: ModelType - - def __init__( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - ): - self.model_path = model_path - self.base_model = base_model - self.model_type = model_type - - def _hf_definition_to_type(self, subtypes: List[str]) -> Type: - if len(subtypes) < 2: - raise Exception("Invalid subfolder definition!") - if all(t is None for t in subtypes): - return None - elif any(t is None for t in subtypes): - raise Exception(f"Unsupported definition: {subtypes}") - - if subtypes[0] in ["diffusers", "transformers"]: - res_type = sys.modules[subtypes[0]] - subtypes = subtypes[1:] - - else: - res_type = sys.modules["diffusers"] - res_type = res_type.pipelines - - for subtype in subtypes: - res_type = getattr(res_type, subtype) - return res_type - - @classmethod - def _get_configs(cls): - with suppress(Exception): - return cls.__configs - - configs = {} - for name in dir(cls): - if name.startswith("__"): - continue - - value = getattr(cls, name) - if not isinstance(value, type) or not issubclass(value, ModelConfigBase): - continue - - if hasattr(inspect, "get_annotations"): - fields = inspect.get_annotations(value) - else: - fields = value.__annotations__ - try: - field = fields["model_format"] - except Exception: - raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") - - if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): - for model_format in field: - configs[model_format.value] = value - - elif typing.get_origin(field) is Literal and all( - isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ - ): - for model_format in field.__args__: - configs[model_format.value] = value - - elif field is None: - configs[None] = value - - else: - raise Exception(f"Unsupported format definition in {cls.__qualname__}") - - cls.__configs = configs - return cls.__configs - - @classmethod - def create_config(cls, **kwargs) -> ModelConfigBase: - if "model_format" not in kwargs: - raise Exception("Field 'model_format' not found in model config") - - configs = cls._get_configs() - return configs[kwargs["model_format"]](**kwargs) - - @classmethod - def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: - return cls.create_config( - path=path, - model_format=cls.detect_format(path), - ) - - @classmethod - @abstractmethod - def detect_format(cls, path: str) -> str: - raise NotImplementedError() - - @classproperty - @abstractmethod - def save_to_config(cls) -> bool: - raise NotImplementedError() - - @abstractmethod - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - raise NotImplementedError() - - @abstractmethod - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> Any: - raise NotImplementedError() - - -class DiffusersModel(ModelBase): - # child_types: Dict[str, Type] - # child_sizes: Dict[str, int] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - super().__init__(model_path, base_model, model_type) - - self.child_types: Dict[str, Type] = {} - self.child_sizes: Dict[str, int] = {} - - try: - config_data = DiffusionPipeline.load_config(self.model_path) - # config_data = json.loads(os.path.join(self.model_path, "model_index.json")) - except Exception: - raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") - - config_data.pop("_ignore_files", None) - - # retrieve all folder_names that contain relevant files - child_components = [k for k, v in config_data.items() if isinstance(v, list)] - - for child_name in child_components: - child_type = self._hf_definition_to_type(config_data[child_name]) - self.child_types[child_name] = child_type - self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - # return pipeline in different function to pass more arguments - if child_type is None: - raise Exception("Child model type can't be null on diffusers model") - if child_type not in self.child_types: - return None # TODO: or raise - - if torch_dtype == torch.float16: - variants = ["fp16", None] - else: - variants = [None, "fp16"] - - # TODO: better error handling(differentiate not found from others) - for variant in variants: - try: - # TODO: set cache_dir to /dev/null to be sure that cache not used? - model = self.child_types[child_type].from_pretrained( - self.model_path, - subfolder=child_type.value, - torch_dtype=torch_dtype, - variant=variant, - local_files_only=True, - ) - break - except Exception as e: - if not str(e).startswith("Error no file"): - print("====ERR LOAD====") - print(f"{variant}: {e}") - pass - else: - raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") - - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - # def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str: - - -def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None): - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - - # this can happen when, for example, the safety checker - # is not downloaded. - if not os.path.exists(model_path): - return 0 - - all_files = os.listdir(model_path) - all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] - - fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f} - bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f} - other_files = set(all_files) - fp16_files - bit8_files - - if variant is None: - files = other_files - elif variant == "fp16": - files = fp16_files - elif variant == "8bit": - files = bit8_files - else: - raise NotImplementedError(f"Unknown variant: {variant}") - - # try read from index if exists - index_postfix = ".index.json" - if variant is not None: - index_postfix = f".index.{variant}.json" - - for file in files: - if not file.endswith(index_postfix): - continue - try: - with open(os.path.join(model_path, file), "r") as f: - index_data = json.loads(f.read()) - return int(index_data["metadata"]["total_size"]) - except Exception: - pass - - # calculate files size if there is no index file - formats = [ - (".safetensors",), # safetensors - (".bin",), # torch - (".onnx", ".pb"), # onnx - (".msgpack",), # flax - (".ckpt",), # tf - (".h5",), # tf2 - ] - - for file_format in formats: - model_files = [f for f in files if f.endswith(file_format)] - if len(model_files) == 0: - continue - - model_size = 0 - for model_file in model_files: - file_stats = os.stat(os.path.join(model_path, model_file)) - model_size += file_stats.st_size - return model_size - - # raise NotImplementedError(f"Unknown model structure! Files: {all_files}") - return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu - - -def calc_model_size_by_data(model) -> int: - if isinstance(model, DiffusionPipeline): - return _calc_pipeline_by_data(model) - elif isinstance(model, torch.nn.Module): - return _calc_model_by_data(model) - elif isinstance(model, IAIOnnxRuntimeModel): - return _calc_onnx_model_by_data(model) - else: - return 0 - - -def _calc_pipeline_by_data(pipeline) -> int: - res = 0 - for submodel_key in pipeline.components.keys(): - submodel = getattr(pipeline, submodel_key) - if submodel is not None and isinstance(submodel, torch.nn.Module): - res += _calc_model_by_data(submodel) - return res - - -def _calc_model_by_data(model) -> int: - mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) - mem = mem_params + mem_bufs # in bytes - return mem - - -def _calc_onnx_model_by_data(model) -> int: - tensor_size = model.tensors.size() * 2 # The session doubles this - mem = tensor_size # in bytes - return mem - - -def _fast_safetensors_reader(path: str): - checkpoint = {} - device = torch.device("meta") - with open(path, "rb") as f: - definition_len = int.from_bytes(f.read(8), "little") - definition_json = f.read(definition_len) - definition = json.loads(definition_json) - - if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { - "pt", - "torch", - "pytorch", - }: - raise Exception("Supported only pytorch safetensors files") - definition.pop("__metadata__", None) - - for key, info in definition.items(): - dtype = { - "I8": torch.int8, - "I16": torch.int16, - "I32": torch.int32, - "I64": torch.int64, - "F16": torch.float16, - "F32": torch.float32, - "F64": torch.float64, - }[info["dtype"]] - - checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) - - return checkpoint - - -def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): - if str(path).endswith(".safetensors"): - try: - checkpoint = _fast_safetensors_reader(path) - except Exception: - # TODO: create issue for support "meta"? - checkpoint = safetensors.torch.load_file(path, device="cpu") - else: - if scan: - scan_result = scan_file_path(path) - if scan_result.infected_files != 0: - raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') - checkpoint = torch.load(path, map_location=torch.device("meta")) - return checkpoint - - -class SilenceWarnings(object): - def __init__(self): - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - - def __enter__(self): - transformers_logging.set_verbosity_error() - diffusers_logging.set_verbosity_error() - warnings.simplefilter("ignore") - - def __exit__(self, type, value, traceback): - transformers_logging.set_verbosity(self.transformers_verbosity) - diffusers_logging.set_verbosity(self.diffusers_verbosity) - warnings.simplefilter("default") - - -ONNX_WEIGHTS_NAME = "model.onnx" - - -class IAIOnnxRuntimeModel: - class _tensor_access: - def __init__(self, model): - self.model = model - self.indexes = {} - for idx, obj in enumerate(self.model.proto.graph.initializer): - self.indexes[obj.name] = idx - - def __getitem__(self, key: str): - value = self.model.proto.graph.initializer[self.indexes[key]] - return numpy_helper.to_array(value) - - def __setitem__(self, key: str, value: np.ndarray): - new_node = numpy_helper.from_array(value) - # set_external_data(new_node, location="in-memory-location") - new_node.name = key - # new_node.ClearField("raw_data") - del self.model.proto.graph.initializer[self.indexes[key]] - self.model.proto.graph.initializer.insert(self.indexes[key], new_node) - # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) - - # __delitem__ - - def __contains__(self, key: str): - return self.indexes[key] in self.model.proto.graph.initializer - - def items(self): - raise NotImplementedError("tensor.items") - # return [(obj.name, obj) for obj in self.raw_proto] - - def keys(self): - return self.indexes.keys() - - def values(self): - raise NotImplementedError("tensor.values") - # return [obj for obj in self.raw_proto] - - def size(self): - bytesSum = 0 - for node in self.model.proto.graph.initializer: - bytesSum += sys.getsizeof(node.raw_data) - return bytesSum - - class _access_helper: - def __init__(self, raw_proto): - self.indexes = {} - self.raw_proto = raw_proto - for idx, obj in enumerate(raw_proto): - self.indexes[obj.name] = idx - - def __getitem__(self, key: str): - return self.raw_proto[self.indexes[key]] - - def __setitem__(self, key: str, value): - index = self.indexes[key] - del self.raw_proto[index] - self.raw_proto.insert(index, value) - - # __delitem__ - - def __contains__(self, key: str): - return key in self.indexes - - def items(self): - return [(obj.name, obj) for obj in self.raw_proto] - - def keys(self): - return self.indexes.keys() - - def values(self): - return list(self.raw_proto) - - def __init__(self, model_path: str, provider: Optional[str]): - self.path = model_path - self.session = None - self.provider = provider - """ - self.data_path = self.path + "_data" - if not os.path.exists(self.data_path): - print(f"Moving model tensors to separate file: {self.data_path}") - tmp_proto = onnx.load(model_path, load_external_data=True) - onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) - del tmp_proto - gc.collect() - - self.proto = onnx.load(model_path, load_external_data=False) - """ - - self.proto = onnx.load(model_path, load_external_data=True) - # self.data = dict() - # for tensor in self.proto.graph.initializer: - # name = tensor.name - - # if tensor.HasField("raw_data"): - # npt = numpy_helper.to_array(tensor) - # orv = OrtValue.ortvalue_from_numpy(npt) - # # self.data[name] = orv - # # set_external_data(tensor, location="in-memory-location") - # tensor.name = name - # # tensor.ClearField("raw_data") - - self.nodes = self._access_helper(self.proto.graph.node) - # self.initializers = self._access_helper(self.proto.graph.initializer) - # print(self.proto.graph.input) - # print(self.proto.graph.initializer) - - self.tensors = self._tensor_access(self) - - # TODO: integrate with model manager/cache - def create_session(self, height=None, width=None): - if self.session is None or self.session_width != width or self.session_height != height: - # onnx.save(self.proto, "tmp.onnx") - # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) - # TODO: something to be able to get weight when they already moved outside of model proto - # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) - sess = SessionOptions() - # self._external_data.update(**external_data) - # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) - # sess.enable_profiling = True - - # sess.intra_op_num_threads = 1 - # sess.inter_op_num_threads = 1 - # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL - # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - # sess.enable_cpu_mem_arena = True - # sess.enable_mem_pattern = True - # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code - self.session_height = height - self.session_width = width - if height and width: - sess.add_free_dimension_override_by_name("unet_sample_batch", 2) - sess.add_free_dimension_override_by_name("unet_sample_channels", 4) - sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) - sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) - sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) - sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) - sess.add_free_dimension_override_by_name("unet_time_batch", 1) - providers = [] - if self.provider: - providers.append(self.provider) - else: - providers = get_available_providers() - if "TensorrtExecutionProvider" in providers: - providers.remove("TensorrtExecutionProvider") - try: - self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) - except Exception as e: - raise e - # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) - # self.io_binding = self.session.io_binding() - - def release_session(self): - self.session = None - import gc - - gc.collect() - return - - def __call__(self, **kwargs): - if self.session is None: - raise Exception("You should call create_session before running model") - - inputs = {k: np.array(v) for k, v in kwargs.items()} - # output_names = self.session.get_outputs() - # for k in inputs: - # self.io_binding.bind_cpu_input(k, inputs[k]) - # for name in output_names: - # self.io_binding.bind_output(name.name) - # self.session.run_with_iobinding(self.io_binding, None) - # return self.io_binding.copy_outputs_to_cpu() - return self.session.run(None, inputs) - - # compatability with diffusers load code - @classmethod - def from_pretrained( - cls, - model_id: Union[str, Path], - subfolder: Union[str, Path] = None, - file_name: Optional[str] = None, - provider: Optional[str] = None, - sess_options: Optional["SessionOptions"] = None, - **kwargs, - ): - file_name = file_name or ONNX_WEIGHTS_NAME - - if os.path.isdir(model_id): - model_path = model_id - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - model_path = os.path.join(model_path, file_name) - - else: - model_path = model_id - - # load model from local directory - if not os.path.isfile(model_path): - raise Exception(f"Model not found: {model_path}") - - # TODO: session options - return cls(model_path, provider=provider) diff --git a/invokeai/backend/model_management/models/clip_vision.py b/invokeai/backend/model_management/models/clip_vision.py deleted file mode 100644 index 2276c6beed1..00000000000 --- a/invokeai/backend/model_management/models/clip_vision.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -from enum import Enum -from typing import Literal, Optional - -import torch -from transformers import CLIPVisionModelWithProjection - -from invokeai.backend.model_management.models.base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class CLIPVisionModelFormat(str, Enum): - Diffusers = "diffusers" - - -class CLIPVisionModel(ModelBase): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[CLIPVisionModelFormat.Diffusers] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.CLIPVision - super().__init__(model_path, base_model, model_type) - - self.model_size = calc_model_size_by_fs(self.model_path) - - @classmethod - def detect_format(cls, path: str) -> str: - if not os.path.exists(path): - raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.") - - if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")): - return CLIPVisionModelFormat.Diffusers - - raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}") - - @classproperty - def save_to_config(cls) -> bool: - return True - - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - if child_type is not None: - raise ValueError("There are no child models in a CLIP Vision model.") - - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> CLIPVisionModelWithProjection: - if child_type is not None: - raise ValueError("There are no child models in a CLIP Vision model.") - - model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype) - - # Calculate a more accurate model size. - self.model_size = calc_model_size_by_data(model) - - return model - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == CLIPVisionModelFormat.Diffusers: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py deleted file mode 100644 index da269eba4b7..00000000000 --- a/invokeai/backend/model_management/models/controlnet.py +++ /dev/null @@ -1,163 +0,0 @@ -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional - -import torch - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class ControlNetModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class ControlNetModel(ModelBase): - # model_class: Type - # model_size: int - - class DiffusersConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Diffusers] - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Checkpoint] - config: str - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.ControlNet - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - # config = json.loads(os.path.join(self.model_path, "config.json")) - except Exception: - raise Exception("Invalid controlnet model! (config.json not found or invalid)") - - model_class_name = config.get("_class_name", None) - if model_class_name not in {"ControlNetModel"}: - raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}") - - try: - self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except Exception: - raise Exception("Invalid ControlNet model!") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in controlnet model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There are no child models in controlnet model") - - model = None - for variant in ["fp16", None]: - try: - model = self.model_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception: - pass - if not model: - raise ModelNotFoundException() - - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return ControlNetModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]): - return ControlNetModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint: - return _convert_controlnet_ckpt_and_cache( - model_path=model_path, - model_config=config.config, - output_path=output_path, - base_model=base_model, - ) - else: - return model_path - - -def _convert_controlnet_ckpt_and_cache( - model_path: str, - output_path: str, - base_model: BaseModelType, - model_config: str, -) -> str: - """ - Convert the controlnet from checkpoint format to diffusers format, - cache it to disk, and return Path to converted - file. If already on disk then just returns Path. - """ - print(f"DEBUG: controlnet config = {model_config}") - app_config = InvokeAIAppConfig.get_config() - weights = app_config.root_path / model_path - output_path = Path(output_path) - - logger.info(f"Converting {weights} to diffusers format") - # return cached version if it exists - if output_path.exists(): - return output_path - - # to avoid circular import errors - from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers - - convert_controlnet_to_diffusers( - weights, - output_path, - original_config_file=app_config.root_path / model_config, - image_size=512, - scan_needed=True, - from_safetensors=weights.suffix == ".safetensors", - ) - return output_path diff --git a/invokeai/backend/model_management/models/ip_adapter.py b/invokeai/backend/model_management/models/ip_adapter.py deleted file mode 100644 index c60edd0abe3..00000000000 --- a/invokeai/backend/model_management/models/ip_adapter.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import typing -from enum import Enum -from typing import Literal, Optional - -import torch - -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter -from invokeai.backend.model_management.models.base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelType, - SubModelType, - calc_model_size_by_fs, - classproperty, -) - - -class IPAdapterModelFormat(str, Enum): - # The custom IP-Adapter model format defined by InvokeAI. - InvokeAI = "invokeai" - - -class IPAdapterModel(ModelBase): - class InvokeAIConfig(ModelConfigBase): - model_format: Literal[IPAdapterModelFormat.InvokeAI] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.IPAdapter - super().__init__(model_path, base_model, model_type) - - self.model_size = calc_model_size_by_fs(self.model_path) - - @classmethod - def detect_format(cls, path: str) -> str: - if not os.path.exists(path): - raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.") - - if os.path.isdir(path): - model_file = os.path.join(path, "ip_adapter.bin") - image_encoder_config_file = os.path.join(path, "image_encoder.txt") - if os.path.exists(model_file) and os.path.exists(image_encoder_config_file): - return IPAdapterModelFormat.InvokeAI - - raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}") - - @classproperty - def save_to_config(cls) -> bool: - return True - - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - if child_type is not None: - raise ValueError("There are no child models in an IP-Adapter model.") - - return self.model_size - - def get_model( - self, - torch_dtype: torch.dtype, - child_type: Optional[SubModelType] = None, - ) -> typing.Union[IPAdapter, IPAdapterPlus]: - if child_type is not None: - raise ValueError("There are no child models in an IP-Adapter model.") - - model = build_ip_adapter( - ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), - device=torch.device("cpu"), - dtype=torch_dtype, - ) - - self.model_size = model.calc_size() - return model - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == IPAdapterModelFormat.InvokeAI: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") - - -def get_ip_adapter_image_encoder_model_id(model_path: str): - """Read the ID of the image encoder associated with the IP-Adapter at `model_path`.""" - image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") - - with open(image_encoder_config_file, "r") as f: - image_encoder_model = f.readline().strip() - - return image_encoder_model diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management/models/sdxl.py deleted file mode 100644 index 01e9420fed7..00000000000 --- a/invokeai/backend/model_management/models/sdxl.py +++ /dev/null @@ -1,148 +0,0 @@ -import json -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional - -from omegaconf import OmegaConf -from pydantic import Field - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae -from invokeai.backend.util.logging import InvokeAILogger - -from .base import ( - BaseModelType, - DiffusersModel, - InvalidModelException, - ModelConfigBase, - ModelType, - ModelVariantType, - classproperty, - read_checkpoint_meta, -) - - -class StableDiffusionXLModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusionXLModel(DiffusersModel): - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner} - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusionXL, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusionXLModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusionXLModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - # avoid circular import - from .stable_diffusion import _select_ckpt_config - - ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if os.path.isdir(model_path): - return StableDiffusionXLModelFormat.Diffusers - else: - return StableDiffusionXLModelFormat.Checkpoint - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - # The convert script adapted from the diffusers package uses - # strings for the base model type. To avoid making too many - # source code changes, we simply translate here - if Path(output_path).exists(): - return output_path - - if isinstance(config, cls.CheckpointConfig): - from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache - - # Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed, - # then we bake it into the converted model unless there is already - # a nonstandard VAE installed. - kwargs = {} - app_config = InvokeAIAppConfig.get_config() - vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix" - if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)): - InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.") - kwargs["vae_path"] = vae_path - - return _convert_ckpt_and_cache( - version=base_model, - model_config=config, - output_path=output_path, - use_safetensors=True, - **kwargs, - ) - else: - return model_path diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py deleted file mode 100644 index a38a44fccf7..00000000000 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ /dev/null @@ -1,337 +0,0 @@ -import json -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional, Union - -from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline -from omegaconf import OmegaConf -from pydantic import Field - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - DiffusersModel, - InvalidModelException, - ModelConfigBase, - ModelNotFoundException, - ModelType, - ModelVariantType, - SilenceWarnings, - classproperty, - read_checkpoint_meta, -) -from .sdxl import StableDiffusionXLModel - - -class StableDiffusion1ModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusion1Model(DiffusersModel): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusion1ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1 - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusion1ModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusion1ModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format") - - else: - raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 1.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if not os.path.exists(model_path): - raise ModelNotFoundException() - - if os.path.isdir(model_path): - if os.path.exists(os.path.join(model_path, "model_index.json")): - return StableDiffusion1ModelFormat.Diffusers - - if os.path.isfile(model_path): - if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return StableDiffusion1ModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {model_path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion1, - model_config=config, - load_safety_checker=False, - output_path=output_path, - ) - else: - return model_path - - -class StableDiffusion2ModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusion2Model(DiffusersModel): - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusion2ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion2 - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion2, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusion2ModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusion2ModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if not os.path.exists(model_path): - raise ModelNotFoundException() - - if os.path.isdir(model_path): - if os.path.exists(os.path.join(model_path, "model_index.json")): - return StableDiffusion2ModelFormat.Diffusers - - if os.path.isfile(model_path): - if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return StableDiffusion2ModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {model_path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion2, - model_config=config, - output_path=output_path, - ) - else: - return model_path - - -# TODO: rework -# pass precision - currently defaulting to fp16 -def _convert_ckpt_and_cache( - version: BaseModelType, - model_config: Union[ - StableDiffusion1Model.CheckpointConfig, - StableDiffusion2Model.CheckpointConfig, - StableDiffusionXLModel.CheckpointConfig, - ], - output_path: str, - use_save_model: bool = False, - **kwargs, -) -> str: - """ - Convert the checkpoint model indicated in mconfig into a - diffusers, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. - """ - app_config = InvokeAIAppConfig.get_config() - - weights = app_config.models_path / model_config.path - config_file = app_config.root_path / model_config.config - output_path = Path(output_path) - variant = model_config.variant - pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline - - # return cached version if it exists - if output_path.exists(): - return output_path - - # to avoid circular import errors - from ...util.devices import choose_torch_device, torch_dtype - from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers - - model_base_to_model_type = { - BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", - BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", - BaseModelType.StableDiffusionXL: "SDXL", - BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", - } - logger.info(f"Converting {weights} to diffusers format") - with SilenceWarnings(): - convert_ckpt_to_diffusers( - weights, - output_path, - model_type=model_base_to_model_type[version], - model_version=version, - model_variant=model_config.variant, - original_config_file=config_file, - extract_ema=True, - scan_needed=True, - pipeline_class=pipeline_class, - from_safetensors=weights.suffix == ".safetensors", - precision=torch_dtype(choose_torch_device()), - **kwargs, - ) - return output_path - - -def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): - ckpt_configs = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", - ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512) - ModelVariantType.Inpaint: "v2-inpainting-inference.yaml", - ModelVariantType.Depth: "v2-midas-inference.yaml", - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - ModelVariantType.Inpaint: None, - ModelVariantType.Depth: None, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - ModelVariantType.Inpaint: None, - ModelVariantType.Depth: None, - }, - } - - app_config = InvokeAIAppConfig.get_config() - try: - config_path = app_config.legacy_conf_path / ckpt_configs[version][variant] - if config_path.is_relative_to(app_config.root_path): - config_path = config_path.relative_to(app_config.root_path) - return str(config_path) - - except Exception: - return None diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_management/models/stable_diffusion_onnx.py deleted file mode 100644 index 2d0dd22c43a..00000000000 --- a/invokeai/backend/model_management/models/stable_diffusion_onnx.py +++ /dev/null @@ -1,150 +0,0 @@ -from enum import Enum -from typing import Literal - -from diffusers import OnnxRuntimeModel - -from .base import ( - BaseModelType, - DiffusersModel, - IAIOnnxRuntimeModel, - ModelConfigBase, - ModelType, - ModelVariantType, - SchedulerPredictionType, - classproperty, -) - - -class StableDiffusionOnnxModelFormat(str, Enum): - Olive = "olive" - Onnx = "onnx" - - -class ONNXStableDiffusion1Model(DiffusersModel): - class Config(ModelConfigBase): - model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1 - assert model_type == ModelType.ONNX - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.ONNX, - ) - - for child_name, child_type in self.child_types.items(): - if child_type is OnnxRuntimeModel: - self.child_types[child_name] = IAIOnnxRuntimeModel - - # TODO: check that no optimum models provided - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - in_channels = 4 # TODO: - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 1.* model format") - - return cls.create_config( - path=path, - model_format=model_format, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - # TODO: Detect onnx vs olive - return StableDiffusionOnnxModelFormat.Onnx - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path - - -class ONNXStableDiffusion2Model(DiffusersModel): - # TODO: check that configs overwriten properly - class Config(ModelConfigBase): - model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion2 - assert model_type == ModelType.ONNX - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion2, - model_type=ModelType.ONNX, - ) - - for child_name, child_type in self.child_types.items(): - if child_type is OnnxRuntimeModel: - self.child_types[child_name] = IAIOnnxRuntimeModel - # TODO: check that no optimum models provided - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - in_channels = 4 # TODO: - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if variant == ModelVariantType.Normal: - prediction_type = SchedulerPredictionType.VPrediction - upcast_attention = True - - else: - prediction_type = SchedulerPredictionType.Epsilon - upcast_attention = False - - return cls.create_config( - path=path, - model_format=model_format, - variant=variant, - prediction_type=prediction_type, - upcast_attention=upcast_attention, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - # TODO: Detect onnx vs olive - return StableDiffusionOnnxModelFormat.Onnx - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path diff --git a/invokeai/backend/model_management/models/t2i_adapter.py b/invokeai/backend/model_management/models/t2i_adapter.py deleted file mode 100644 index 4adb9901f99..00000000000 --- a/invokeai/backend/model_management/models/t2i_adapter.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -from enum import Enum -from typing import Literal, Optional - -import torch -from diffusers import T2IAdapter - -from invokeai.backend.model_management.models.base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class T2IAdapterModelFormat(str, Enum): - Diffusers = "diffusers" - - -class T2IAdapterModel(ModelBase): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[T2IAdapterModelFormat.Diffusers] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.T2IAdapter - super().__init__(model_path, base_model, model_type) - - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - - model_class_name = config.get("_class_name", None) - if model_class_name not in {"T2IAdapter"}: - raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.") - - self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> T2IAdapter: - if child_type is not None: - raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.") - - model = None - for variant in ["fp16", None]: - try: - model = self.model_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception: - pass - if not model: - raise ModelNotFoundException() - - # Calculate a more accurate size after loading the model into memory. - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException(f"Model not found at '{path}'.") - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return T2IAdapterModelFormat.Diffusers - - raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == T2IAdapterModelFormat.Diffusers: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py deleted file mode 100644 index 99358704b8d..00000000000 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from typing import Optional - -import torch - -# TODO: naming -from ..lora import TextualInversionModel as TextualInversionModelRaw -from .base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - classproperty, -) - - -class TextualInversionModel(ModelBase): - # model_size: int - - class Config(ModelConfigBase): - model_format: None - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.TextualInversion - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - - checkpoint_path = self.model_path - if os.path.isdir(checkpoint_path): - checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin") - - if not os.path.exists(checkpoint_path): - raise ModelNotFoundException() - - model = TextualInversionModelRaw.from_checkpoint( - file_path=checkpoint_path, - dtype=torch_dtype, - ) - - self.model_size = model.embedding.nelement() * model.embedding.element_size() - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "learned_embeds.bin")): - return None # diffusers-ti - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]): - return None - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py deleted file mode 100644 index 8cc37e67a73..00000000000 --- a/invokeai/backend/model_management/models/vae.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -from enum import Enum -from pathlib import Path -from typing import Optional - -import safetensors -import torch -from omegaconf import OmegaConf - -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - ModelVariantType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class VaeModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class VaeModel(ModelBase): - # vae_class: Type - # model_size: int - - class Config(ModelConfigBase): - model_format: VaeModelFormat - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Vae - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - # config = json.loads(os.path.join(self.model_path, "config.json")) - except Exception: - raise Exception("Invalid vae model! (config.json not found or invalid)") - - try: - vae_class_name = config.get("_class_name", "AutoencoderKL") - self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except Exception: - raise Exception("Invalid vae model! (Unkown vae type)") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in vae model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in vae model") - - model = self.vae_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException(f"Does not exist as local file: {path}") - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return VaeModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return VaeModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, # empty config or config of parent model - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: - return _convert_vae_ckpt_and_cache( - weights_path=model_path, - output_path=output_path, - base_model=base_model, - model_config=config, - ) - else: - return model_path - - -# TODO: rework -def _convert_vae_ckpt_and_cache( - weights_path: str, - output_path: str, - base_model: BaseModelType, - model_config: ModelConfigBase, -) -> str: - """ - Convert the VAE indicated in mconfig into a diffusers AutoencoderKL - object, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. - """ - app_config = InvokeAIAppConfig.get_config() - weights_path = app_config.root_dir / weights_path - output_path = Path(output_path) - - """ - this size used only in when tiling enabled to separate input in tiles - sizes in configs from stable diffusion githubs(1 and 2) set to 256 - on huggingface it: - 1.5 - 512 - 1.5-inpainting - 256 - 2-inpainting - 512 - 2-depth - 256 - 2-base - 512 - 2 - 768 - 2.1-base - 768 - 2.1 - 768 - """ - image_size = 512 - - # return cached version if it exists - if output_path.exists(): - return output_path - - if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - from .stable_diffusion import _select_ckpt_config - - # all sd models use same vae settings - config_file = _select_ckpt_config(base_model, ModelVariantType.Normal) - else: - raise Exception(f"Vae conversion not supported for model type: {base_model}") - - # this avoids circular import error - from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers - - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") - else: - checkpoint = torch.load(weights_path, map_location="cpu") - - # sometimes weights are hidden under "state_dict", and sometimes not - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - config = OmegaConf.load(app_config.root_path / config_file) - - vae_model = convert_ldm_vae_to_diffusers( - checkpoint=checkpoint, - vae_config=config, - image_size=image_size, - ) - vae_model.save_pretrained(output_path, safe_serialization=True) - return output_path diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management/seamless.py deleted file mode 100644 index bfdf9e0c536..00000000000 --- a/invokeai/backend/model_management/seamless.py +++ /dev/null @@ -1,102 +0,0 @@ -from __future__ import annotations - -from contextlib import contextmanager -from typing import List, Union - -import torch.nn as nn -from diffusers.models import AutoencoderKL, UNet2DConditionModel - - -def _conv_forward_asymmetric(self, input, weight, bias): - """ - Patch for Conv2d._conv_forward that supports asymmetric padding - """ - working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) - working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) - return nn.functional.conv2d( - working, - weight, - bias, - self.stride, - nn.modules.utils._pair(0), - self.dilation, - self.groups, - ) - - -@contextmanager -def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): - try: - to_restore = [] - - for m_name, m in model.named_modules(): - if isinstance(model, UNet2DConditionModel): - if ".attentions." in m_name: - continue - - if ".resnets." in m_name: - if ".conv2" in m_name: - continue - if ".conv_shortcut" in m_name: - continue - - """ - if isinstance(model, UNet2DConditionModel): - if False and ".upsamplers." in m_name: - continue - - if False and ".downsamplers." in m_name: - continue - - if True and ".resnets." in m_name: - if True and ".conv1" in m_name: - if False and "down_blocks" in m_name: - continue - if False and "mid_block" in m_name: - continue - if False and "up_blocks" in m_name: - continue - - if True and ".conv2" in m_name: - continue - - if True and ".conv_shortcut" in m_name: - continue - - if True and ".attentions." in m_name: - continue - - if False and m_name in ["conv_in", "conv_out"]: - continue - """ - - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - m.asymmetric_padding_mode = {} - m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" - m.asymmetric_padding["x"] = ( - m._reversed_padding_repeated_twice[0], - m._reversed_padding_repeated_twice[1], - 0, - 0, - ) - m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" - m.asymmetric_padding["y"] = ( - 0, - 0, - m._reversed_padding_repeated_twice[2], - m._reversed_padding_repeated_twice[3], - ) - - to_restore.append((m, m._conv_forward)) - m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) - - yield - - finally: - for module, orig_conv_forward in to_restore: - module._conv_forward = orig_conv_forward - if hasattr(module, "asymmetric_padding_mode"): - del module.asymmetric_padding_mode - if hasattr(module, "asymmetric_padding"): - del module.asymmetric_padding diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 0f16852c934..88356d04686 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,6 +1,6 @@ """Re-export frequently-used symbols from the Model Manager backend.""" - from .config import ( + AnyModel, AnyModelConfig, BaseModelType, InvalidModelConfigException, @@ -12,14 +12,17 @@ SchedulerPredictionType, SubModelType, ) +from .load import LoadedModel from .probe import ModelProbe from .search import ModelSearch __all__ = [ + "AnyModel", "AnyModelConfig", "BaseModelType", "ModelRepoVariant", "InvalidModelConfigException", + "LoadedModel", "ModelConfigFactory", "ModelFormat", "ModelProbe", @@ -29,3 +32,42 @@ "SchedulerPredictionType", "SubModelType", ] + +########## to help populate the openapi_schema with format enums for each config ########### +# This code is no longer necessary? +# leave it here just in case +# +# import inspect +# from enum import Enum +# from typing import Any, Iterable, Dict, get_args, Set +# def _expand(something: Any) -> Iterable[type]: +# if isinstance(something, type): +# yield something +# else: +# for x in get_args(something): +# for y in _expand(x): +# yield y + +# def _find_format(cls: type) -> Iterable[Enum]: +# if hasattr(inspect, "get_annotations"): +# fields = inspect.get_annotations(cls) +# else: +# fields = cls.__annotations__ +# if "format" in fields: +# for x in get_args(fields["format"]): +# yield x +# for parent_class in cls.__bases__: +# for x in _find_format(parent_class): +# yield x +# return None + +# def get_model_config_formats() -> Dict[str, Set[Enum]]: +# result: Dict[str, Set[Enum]] = {} +# for model_config in _expand(AnyModelConfig): +# for field in _find_format(model_config): +# if field is None: +# continue +# if not result.get(model_config.__qualname__): +# result[model_config.__qualname__] = set() +# result[model_config.__qualname__].add(field) +# return result diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 964cc19f196..bc4848b0a50 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -19,12 +19,21 @@ Validation errors will raise an InvalidModelConfigException error. """ +import time from enum import Enum from typing import Literal, Optional, Type, Union +import torch +from diffusers import ModelMixin from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict +from ..raw_model import RawModel + +# ModelMixin is the base class for all diffusers and transformers models +# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] + class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -102,7 +111,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - DEFAULT = "default" # model files without "fp16" or other qualifier + DEFAULT = "" # model files without "fp16" or other qualifier - empty str FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" @@ -113,11 +122,11 @@ class ModelRepoVariant(str, Enum): class ModelConfigBase(BaseModel): """Base class for model configuration information.""" - path: str - name: str - base: BaseModelType - type: ModelType - format: ModelFormat + path: str = Field(description="filesystem path to the model file or directory") + name: str = Field(description="model name") + base: BaseModelType = Field(description="base model") + type: ModelType = Field(description="type of the model") + format: ModelFormat = Field(description="model format") key: str = Field(description="unique key for model", default="") original_hash: Optional[str] = Field( description="original fasthash of model contents", default=None @@ -125,8 +134,9 @@ class ModelConfigBase(BaseModel): current_hash: Optional[str] = Field( description="current fasthash of model contents", default=None ) # if model is converted or otherwise modified, this will hold updated hash - description: Optional[str] = Field(default=None) - source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) + description: Optional[str] = Field(description="human readable description of the model", default=None) + source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, @@ -150,6 +160,7 @@ class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT class LoRAConfig(ModelConfigBase): @@ -199,6 +210,8 @@ class _MainConfig(ModelConfigBase): vae: Optional[str] = Field(default=None) variant: ModelVariantType = ModelVariantType.Normal + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False ztsnr_training: bool = False @@ -212,8 +225,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False class ONNXSD1Config(_MainConfig): @@ -237,10 +248,21 @@ class ONNXSD2Config(_MainConfig): upcast_attention: bool = True +class ONNXSDXLConfig(_MainConfig): + """Model config for ONNX format models based on sdxl.""" + + type: Literal[ModelType.ONNX] = ModelType.ONNX + format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + # No yaml config file for ONNX, so these are part of config + base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL + prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction + + class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter + image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] @@ -258,7 +280,7 @@ class T2IConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] -_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")] +_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")] _ControlNetConfig = Annotated[ Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format"), @@ -271,6 +293,7 @@ class T2IConfig(ModelConfigBase): _ONNXConfig, _VaeConfig, _ControlNetConfig, + # ModelConfigBase, LoRAConfig, TextualInversionConfig, IPAdapterConfig, @@ -280,6 +303,7 @@ class T2IConfig(ModelConfigBase): AnyModelConfigValidator = TypeAdapter(AnyModelConfig) + # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown # below. However, it breaks FastAPI when used as the input Body parameter in a route. @@ -308,9 +332,10 @@ class ModelConfigFactory(object): @classmethod def make_config( cls, - model_data: Union[dict, AnyModelConfig], + model_data: Union[Dict[str, Any], AnyModelConfig], key: Optional[str] = None, - dest_class: Optional[Type] = None, + dest_class: Optional[Type[ModelConfigBase]] = None, + timestamp: Optional[float] = None, ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. @@ -321,12 +346,17 @@ def make_config( :param dest_class: The config class to be returned. If not provided, will be selected automatically. """ + model: Optional[ModelConfigBase] = None if isinstance(model_data, ModelConfigBase): model = model_data elif dest_class: - model = dest_class.validate_python(model_data) + model = dest_class.model_validate(model_data) else: - model = AnyModelConfigValidator.validate_python(model_data) + # mypy doesn't typecheck TypeAdapters well? + model = AnyModelConfigValidator.validate_python(model_data) # type: ignore + assert model is not None if key: model.key = key - return model + if timestamp: + model.last_modified = timestamp + return model # type: ignore diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py similarity index 99% rename from invokeai/backend/model_management/convert_ckpt_to_diffusers.py rename to invokeai/backend/model_manager/convert_ckpt_to_diffusers.py index 6878218f679..6f5acd58329 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -57,10 +57,9 @@ ) from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import BaseModelType, ModelVariantType from invokeai.backend.util.logging import InvokeAILogger -from .models import BaseModelType, ModelVariantType - try: from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig @@ -1643,6 +1642,8 @@ def download_controlnet_from_original_ckpt( cross_attention_dim: Optional[bool] = None, scan_needed: bool = False, ) -> DiffusionPipeline: + from omegaconf import OmegaConf + if from_safetensors: from safetensors import safe_open @@ -1718,6 +1719,7 @@ def convert_ckpt_to_diffusers( """ pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + # TO DO: save correct repo variant pipe.save_pretrained( dump_path, safe_serialization=use_safetensors, @@ -1736,4 +1738,5 @@ def convert_controlnet_to_diffusers( """ pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + # TO DO: save correct repo variant pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_management/libc_util.py b/invokeai/backend/model_manager/libc_util.py similarity index 100% rename from invokeai/backend/model_management/libc_util.py rename to invokeai/backend/model_manager/libc_util.py diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py new file mode 100644 index 00000000000..a0421017db9 --- /dev/null +++ b/invokeai/backend/model_manager/load/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team +""" +Init file for the model loader. +""" +from importlib import import_module +from pathlib import Path + +from .convert_cache.convert_cache_default import ModelConvertCache +from .load_base import LoadedModel, ModelLoaderBase +from .load_default import ModelLoader +from .model_cache.model_cache_default import ModelCache +from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase + +# This registers the subclasses that implement loaders of specific model types +loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] +for module in loaders: + import_module(f"{__package__}.model_loaders.{module}") + +__all__ = [ + "LoadedModel", + "ModelCache", + "ModelConvertCache", + "ModelLoaderBase", + "ModelLoader", + "ModelLoaderRegistryBase", + "ModelLoaderRegistry", +] diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py new file mode 100644 index 00000000000..5be56d2d584 --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -0,0 +1,4 @@ +from .convert_cache_base import ModelConvertCacheBase +from .convert_cache_default import ModelConvertCache + +__all__ = ["ModelConvertCacheBase", "ModelConvertCache"] diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py new file mode 100644 index 00000000000..6268c099a5f --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -0,0 +1,27 @@ +""" +Disk-based converted model cache. +""" +from abc import ABC, abstractmethod +from pathlib import Path + + +class ModelConvertCacheBase(ABC): + @property + @abstractmethod + def max_size(self) -> float: + """Return the maximum size of this cache directory.""" + pass + + @abstractmethod + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + pass + + @abstractmethod + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + pass diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py new file mode 100644 index 00000000000..84f4f76299a --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -0,0 +1,72 @@ +""" +Placeholder for convert cache implementation. +""" + +import shutil +from pathlib import Path + +from invokeai.backend.util import GIG, directory_size +from invokeai.backend.util.logging import InvokeAILogger + +from .convert_cache_base import ModelConvertCacheBase + + +class ModelConvertCache(ModelConvertCacheBase): + def __init__(self, cache_path: Path, max_size: float = 10.0): + """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" + if not cache_path.exists(): + cache_path.mkdir(parents=True) + self._cache_path = cache_path + self._max_size = max_size + + @property + def max_size(self) -> float: + """Return the maximum size of this cache directory (GB).""" + return self._max_size + + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + return self._cache_path / key + + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + size_needed = directory_size(self._cache_path) + size + max_size = int(self.max_size) * GIG + logger = InvokeAILogger.get_logger() + + if size_needed <= max_size: + return + + logger.debug( + f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming." + ) + + # For this to work, we make the assumption that the directory contains + # a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level. + # This should be true for any diffusers model. + def by_atime(path: Path) -> float: + for config in ["model_index.json", "unet/config.json", "config.json"]: + sentinel = path / config + if sentinel.exists(): + return sentinel.stat().st_atime + + # no sentinel file found! - pick the most recent file in the directory + try: + atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True) + return atimes[0] + except IndexError: + return 0.0 + + # sort by last access time - least accessed files will be at the end + lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) + logger.debug(f"cached models in descending atime order: {lru_models}") + while size_needed > max_size and len(lru_models) > 0: + next_victim = lru_models.pop() + victim_size = directory_size(next_victim) + logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB") + shutil.rmtree(next_victim) + size_needed -= victim_size diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py new file mode 100644 index 00000000000..0e085792547 --- /dev/null +++ b/invokeai/backend/model_manager/load/load_base.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +""" +Base class for model loading in InvokeAI. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from logging import Logger +from pathlib import Path +from typing import Any, Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + SubModelType, +) +from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase + + +@dataclass +class LoadedModel: + """Context manager object that mediates transfer from RAM<->VRAM.""" + + config: AnyModelConfig + _locker: ModelLockerBase + + def __enter__(self) -> AnyModel: + """Context entry.""" + self._locker.lock() + return self.model + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Context exit.""" + self._locker.unlock() + + @property + def model(self) -> AnyModel: + """Return the model without locking it.""" + return self._locker.model + + +class ModelLoaderBase(ABC): + """Abstract base class for loading models into RAM/VRAM.""" + + @abstractmethod + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + pass + + @abstractmethod + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its confguration. + + Given a model identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param model_config: Model configuration, as returned by ModelConfigRecordStore + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + pass + + @abstractmethod + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Return size in bytes of the model, calculated before loading.""" + pass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py new file mode 100644 index 00000000000..642cffaf4be --- /dev/null +++ b/invokeai/backend/model_manager/load/load_default.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Default implementation of model loading in InvokeAI.""" + +from logging import Logger +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + InvalidModelConfigException, + ModelRepoVariant, + SubModelType, +) +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.util.devices import choose_torch_device, torch_dtype + + +# TO DO: The loader is not thread safe! +class ModelLoader(ModelLoaderBase): + """Default implementation of ModelLoaderBase.""" + + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + self._app_config = app_config + self._logger = logger + self._ram_cache = ram_cache + self._convert_cache = convert_cache + self._torch_dtype = torch_dtype(choose_torch_device(), app_config) + + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its configuration. + + Given a model's configuration as returned by the ModelRecordConfigStore service, + return a LoadedModel object that can be used for inference. + + :param model config: Configuration record for this model + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + if model_config.type == "main" and not submodel_type: + raise InvalidModelConfigException("submodel_type is required when loading a main model") + + model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) + + if not model_path.exists(): + raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}") + + model_path = self._convert_if_needed(model_config, model_path, submodel_type) + locker = self._load_if_needed(model_config, model_path, submodel_type) + return LoadedModel(config=model_config, _locker=locker) + + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + model_base = self._app_config.models_path + result = (model_base / config.path).resolve(), config, submodel_type + return result + + def _convert_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> Path: + cache_path: Path = self._convert_cache.cache_path(config.key) + + if not self._needs_conversion(config, model_path, cache_path): + return cache_path if cache_path.exists() else model_path + + self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + return self._convert_model(config, model_path, cache_path) + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: + return False + + def _load_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> ModelLockerBase: + # TO DO: This is not thread safe! + try: + return self._ram_cache.get(config.key, submodel_type) + except IndexError: + pass + + model_variant = getattr(config, "repo_variant", None) + self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + + # This is where the model is actually loaded! + with skip_torch_weight_init(): + loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type) + + self._ram_cache.put( + config.key, + submodel_type=submodel_type, + model=loaded_model, + size=calc_model_size_by_data(loaded_model), + ) + + return self._ram_cache.get( + key=config.key, + submodel_type=submodel_type, + stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]), + ) + + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Get the size of the model on disk.""" + return calc_model_size_by_fs( + model_path=model_path, + subfolder=submodel_type.value if submodel_type else None, + variant=config.repo_variant if hasattr(config, "repo_variant") else None, + ) + + # This needs to be implemented in subclasses that handle checkpoints + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + raise NotImplementedError + + # This needs to be implemented in the subclass + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + raise NotImplementedError diff --git a/invokeai/backend/model_management/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py similarity index 94% rename from invokeai/backend/model_management/memory_snapshot.py rename to invokeai/backend/model_manager/load/memory_snapshot.py index fe54af191ce..195e39361b4 100644 --- a/invokeai/backend/model_management/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -3,8 +3,9 @@ import psutil import torch +from typing_extensions import Self -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from ..util.libc_util import LibcUtil, Struct_mallinfo2 GB = 2**30 # 1 GB @@ -27,7 +28,7 @@ def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[ self.malloc_info = malloc_info @classmethod - def capture(cls, run_garbage_collector: bool = True): + def capture(cls, run_garbage_collector: bool = True) -> Self: """Capture and return a MemorySnapshot. Note: This function has significant overhead, particularly if `run_garbage_collector == True`. @@ -67,7 +68,7 @@ def capture(cls, run_garbage_collector: bool = True): def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: """Get a pretty string describing the difference between two `MemorySnapshot`s.""" - def get_msg_line(prefix: str, val1: int, val2: int): + def get_msg_line(prefix: str, val1: int, val2: int) -> str: diff = val2 - val1 return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py new file mode 100644 index 00000000000..32c682d0424 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -0,0 +1,6 @@ +"""Init file for ModelCache.""" + +from .model_cache_base import ModelCacheBase, CacheStats # noqa F401 +from .model_cache_default import ModelCache # noqa F401 + +_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py new file mode 100644 index 00000000000..4a4a3c7d299 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from logging import Logger +from typing import Dict, Generic, Optional, TypeVar + +import torch + +from invokeai.backend.model_manager.config import AnyModel, SubModelType + + +class ModelLockerBase(ABC): + """Base class for the model locker used by the loader.""" + + @abstractmethod + def lock(self) -> AnyModel: + """Lock the contained model and move it into VRAM.""" + pass + + @abstractmethod + def unlock(self) -> None: + """Unlock the contained model, and remove it from VRAM.""" + pass + + @property + @abstractmethod + def model(self) -> AnyModel: + """Return the model.""" + pass + + +T = TypeVar("T") + + +@dataclass +class CacheRecord(Generic[T]): + """Elements of the cache.""" + + key: str + model: T + size: int + loaded: bool = False + _locks: int = 0 + + def lock(self) -> None: + """Lock this record.""" + self._locks += 1 + + def unlock(self) -> None: + """Unlock this record.""" + self._locks -= 1 + assert self._locks >= 0 + + @property + def locked(self) -> bool: + """Return true if record is locked.""" + return self._locks > 0 + + +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + +class ModelCacheBase(ABC, Generic[T]): + """Virtual base class for RAM model cache.""" + + @property + @abstractmethod + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + pass + + @property + @abstractmethod + def execution_device(self) -> torch.device: + """Return the exection device (e.g. "cuda" for VRAM).""" + pass + + @property + @abstractmethod + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @property + @abstractmethod + def max_cache_size(self) -> float: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self, size_required: int) -> None: + """Offload from VRAM any models not actively in use.""" + pass + + @abstractmethod + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None: + """Move model into the indicated device.""" + pass + + @property + @abstractmethod + def stats(self) -> CacheStats: + """Return collected CacheStats object.""" + pass + + @stats.setter + @abstractmethod + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collectin cache statistics.""" + pass + + @property + @abstractmethod + def logger(self) -> Logger: + """Return the logger used by the cache.""" + pass + + @abstractmethod + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + pass + + @abstractmethod + def put( + self, + key: str, + model: T, + size: int, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + pass + + @abstractmethod + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. + """ + pass + + @abstractmethod + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + pass + + @abstractmethod + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + pass + + @abstractmethod + def print_cuda_stats(self) -> None: + """Log debugging information on CUDA usage.""" + pass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py new file mode 100644 index 00000000000..02ce1266c75 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -0,0 +1,407 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. + +The cache returns context manager generators designed to load the +model into the GPU within the context, and unload outside the +context. Use like this: + + cache = ModelCache(max_cache_size=7.5) + with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, + cache.get_model('stabilityai/stable-diffusion-2') as SD2: + do_something_in_GPU(SD1,SD2) + + +""" + +import gc +import logging +import math +import sys +import time +from contextlib import suppress +from logging import Logger +from typing import Dict, List, Optional + +import torch + +from invokeai.backend.model_manager import AnyModel, SubModelType +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.logging import InvokeAILogger + +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase +from .model_locker import ModelLocker + +if choose_torch_device() == torch.device("mps"): + from torch import mps + +# Maximum size of the cache, in gigs +# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously +DEFAULT_MAX_CACHE_SIZE = 6.0 + +# amount of GPU memory to hold in reserve for use by generations (GB) +DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 + +# actual size of a gig +GIG = 1073741824 + +# Size of a MB in bytes. +MB = 2**20 + + +class ModelCache(ModelCacheBase[AnyModel]): + """Implementation of ModelCacheBase.""" + + def __init__( + self, + max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, + max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, + execution_device: torch.device = torch.device("cuda"), + storage_device: torch.device = torch.device("cpu"), + precision: torch.dtype = torch.float16, + sequential_offload: bool = False, + lazy_offloading: bool = True, + sha_chunksize: int = 16777216, + log_memory_usage: bool = False, + logger: Optional[Logger] = None, + ): + """ + Initialize the model RAM cache. + + :param max_cache_size: Maximum size of the RAM cache [6.0 GB] + :param execution_device: Torch device to load active model into [torch.device('cuda')] + :param storage_device: Torch device to save inactive model in [torch.device('cpu')] + :param precision: Precision for loaded models [torch.float16] + :param lazy_offloading: Keep model in VRAM until another model needs to be loaded + :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially + :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache + operation, and the result will be logged (at debug level). There is a time cost to capturing the memory + snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's + behaviour. + """ + # allow lazy offloading only when vram cache enabled + self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0 + self._precision: torch.dtype = precision + self._max_cache_size: float = max_cache_size + self._max_vram_cache_size: float = max_vram_cache_size + self._execution_device: torch.device = execution_device + self._storage_device: torch.device = storage_device + self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) + self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG + # used for stats collection + self._stats: Optional[CacheStats] = None + + self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} + self._cache_stack: List[str] = [] + + @property + def logger(self) -> Logger: + """Return the logger used by the cache.""" + return self._logger + + @property + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + return self._lazy_offloading + + @property + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + return self._storage_device + + @property + def execution_device(self) -> torch.device: + """Return the exection device (e.g. "cuda" for VRAM).""" + return self._execution_device + + @property + def max_cache_size(self) -> float: + """Return the cap on cache size.""" + return self._max_cache_size + + @property + def stats(self) -> Optional[CacheStats]: + """Return collected CacheStats object.""" + return self._stats + + @stats.setter + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collectin cache statistics.""" + self._stats = stats + + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + total = 0 + for cache_record in self._cached_models.values(): + total += cache_record.size + return total + + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + key = self._make_cache_key(key, submodel_type) + return key in self._cached_models + + def put( + self, + key: str, + model: AnyModel, + size: int, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + key = self._make_cache_key(key, submodel_type) + assert key not in self._cached_models + + cache_record = CacheRecord(key, model, size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) + + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. + """ + key = self._make_cache_key(key, submodel_type) + if key in self._cached_models: + if self.stats: + self.stats.hits += 1 + else: + if self.stats: + self.stats.misses += 1 + raise IndexError(f"The model with key {key} is not in the cache.") + + cache_entry = self._cached_models[key] + + # more stats + if self.stats: + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) + + # this moves the entry to the top (right end) of the stack + with suppress(Exception): + self._cache_stack.remove(key) + self._cache_stack.append(key) + return ModelLocker( + cache=self, + cache_entry=cache_entry, + ) + + def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: + if self._log_memory_usage: + return MemorySnapshot.capture() + return None + + def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: + if submodel_type: + return f"{model_key}:{submodel_type.value}" + else: + return model_key + + def offload_unlocked_models(self, size_required: int) -> None: + """Move any unused models from VRAM.""" + reserved = self._max_vram_cache_size * GIG + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") + for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): + if vram_in_use <= reserved: + break + if not cache_entry.loaded: + continue + if not cache_entry.locked: + self.move_model_to_device(cache_entry, self.storage_device) + cache_entry.loaded = False + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug( + f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" + ) + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: + """Move model into the indicated device.""" + # These attributes are not in the base ModelMixin class but in various derived classes. + # Some models don't have these attributes, in which case they run in RAM/CPU. + self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") + if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): + return + + source_device = cache_entry.model.device + + # Note: We compare device types only so that 'cuda' == 'cuda:0'. + # This would need to be revised to support multi-GPU. + if torch.device(source_device).type == torch.device(target_device).type: + return + + start_model_to_time = time.time() + snapshot_before = self._capture_memory_snapshot() + cache_entry.model.to(target_device) + snapshot_after = self._capture_memory_snapshot() + end_model_to_time = time.time() + self.logger.debug( + f"Moved model '{cache_entry.key}' from {source_device} to" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + if ( + snapshot_before is not None + and snapshot_after is not None + and snapshot_before.vram is not None + and snapshot_after.vram is not None + ): + vram_change = abs(snapshot_before.vram - snapshot_after.vram) + + # If the estimated model size does not match the change in VRAM, log a warning. + if not math.isclose( + vram_change, + cache_entry.size, + rel_tol=0.1, + abs_tol=10 * MB, + ): + self.logger.debug( + f"Moving model '{cache_entry.key}' from {source_device} to" + f" {target_device} caused an unexpected change in VRAM usage. The model's" + " estimated size may be incorrect. Estimated model size:" + f" {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + def print_cuda_stats(self) -> None: + """Log CUDA diagnostics.""" + vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) + ram = "%4.2fG" % (self.cache_size() / GIG) + + in_ram_models = 0 + in_vram_models = 0 + locked_in_vram_models = 0 + for cache_record in self._cached_models.values(): + if hasattr(cache_record.model, "device"): + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 + if cache_record.locked: + locked_in_vram_models += 1 + + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" + ) + + def make_room(self, model_size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + # calculate how much memory this model will require + # multiplier = 2 if self.precision==torch.float32 else 1 + bytes_needed = model_size + maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes + current_size = self.cache_size() + + if current_size + bytes_needed > maximum_size: + self.logger.debug( + f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" + f" {(bytes_needed/GIG):.2f} GB" + ) + + self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") + + pos = 0 + models_cleared = 0 + while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): + model_key = self._cache_stack[pos] + cache_entry = self._cached_models[model_key] + + refs = sys.getrefcount(cache_entry.model) + + # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly + # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: + # https://docs.python.org/3/library/gc.html#gc.get_referrers + + # manualy clear local variable references of just finished function calls + # for some reason python don't want to collect it even by gc.collect() immidiately + if refs > 2: + while True: + cleared = False + for referrer in gc.get_referrers(cache_entry.model): + if type(referrer).__name__ == "frame": + # RuntimeError: cannot clear an executing frame + with suppress(RuntimeError): + referrer.clear() + cleared = True + # break + + # repeat if referrers changes(due to frame clear), else exit loop + if cleared: + gc.collect() + else: + break + + device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None + self.logger.debug( + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," + f" refs: {refs}" + ) + + # Expected refs: + # 1 from cache_entry + # 1 from getrefcount function + # 1 from onnx runtime object + if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): + self.logger.debug( + f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + ) + current_size -= cache_entry.size + models_cleared += 1 + del self._cache_stack[pos] + del self._cached_models[model_key] + del cache_entry + + else: + pos += 1 + + if models_cleared > 0: + # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but + # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost + # is high even if no garbage gets collected.) + # + # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: + # - If models had to be cleared, it's a signal that we are close to our memory limit. + # - If models were cleared, there's a good chance that there's a significant amount of garbage to be + # collected. + # + # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up + # immediately when their reference count hits 0. + gc.collect() + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py new file mode 100644 index 00000000000..7a5fdd4284b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -0,0 +1,59 @@ +""" +Base class and implementation of a class that moves models in and out of VRAM. +""" + +from invokeai.backend.model_manager import AnyModel + +from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase + + +class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]): + """ + Initialize the model locker. + + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> AnyModel: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + self._cache_entry.loaded = True + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.print_cuda_stats() diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py new file mode 100644 index 00000000000..ce1110e749b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loader_registry.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +""" +This module implements a system in which model loaders register the +type, base and format of models that they know how to load. + +Use like this: + + cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore + loaded_model = cls( + app_config=app_config, + logger=logger, + ram_cache=ram_cache, + convert_cache=convert_cache + ).load_model(model_config, submodel_type) + +""" +import hashlib +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Dict, Optional, Tuple, Type + +from ..config import ( + AnyModelConfig, + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) +from . import ModelLoaderBase + + +class ModelLoaderRegistryBase(ABC): + """This class allows model loaders to register their type, base and format.""" + + @classmethod + @abstractmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + @classmethod + @abstractmethod + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """ + Get subclass of ModelLoaderBase registered to handle base and type. + + Parameters: + :param config: Model configuration record, as returned by ModelRecordService + :param submodel_type: Submodel to fetch (main models only) + :return: tuple(loader_class, model_config, submodel_type) + + Note that the returned model config may be different from one what passed + in, in the event that a submodel type is provided. + """ + + +class ModelLoaderRegistry: + """ + This class allows model loaders to register their type, base and format. + """ + + _registry: Dict[str, Type[ModelLoaderBase]] = {} + + @classmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: + key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) + cls._registry[key] = subclass + return subclass + + return decorator + + @classmethod + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """Get subclass of ModelLoaderBase registered to handle base and type.""" + # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned + conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) + + key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type + key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any + implementation = cls._registry.get(key1) or cls._registry.get(key2) + if not implementation: + raise NotImplementedError( + f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" + ) + return implementation, conf2, submodel_type + + @classmethod + def _handle_subtype_overrides( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: + if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: + model_path = Path(config.vae) + config_class = ( + VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig + ) + hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() + new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) + submodel_type = None + else: + new_conf = config + return new_conf, submodel_type + + @staticmethod + def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: + return "-".join([base.value, type.value, format.value]) diff --git a/invokeai/backend/model_manager/load/model_loaders/__init__.py b/invokeai/backend/model_manager/load/model_loaders/__init__.py new file mode 100644 index 00000000000..962cba54811 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/__init__.py @@ -0,0 +1,3 @@ +""" +Init file for model_loaders. +""" diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py new file mode 100644 index 00000000000..43393f5a847 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for ControlNet model loading in InvokeAI.""" + +from pathlib import Path + +import safetensors +import torch + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers + +from .. import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) +class ControlnetLoader(GenericDiffusersLoader): + """Class to load ControlNet models.""" + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"Vae conversion not supported for model type: {config.base}") + else: + assert hasattr(config, "config") + config_file = config.config + + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") + else: + checkpoint = torch.load(model_path, map_location="cpu") + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + convert_controlnet_to_diffusers( + model_path, + output_path, + original_config_file=self._app_config.root_path / config_file, + image_size=512, + scan_needed=True, + from_safetensors=model_path.suffix == ".safetensors", + ) + return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py new file mode 100644 index 00000000000..9a9b25aec53 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for simple diffusers model loading in InvokeAI.""" + +import sys +from pathlib import Path +from typing import Any, Dict, Optional + +from diffusers import ConfigMixin, ModelMixin + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + InvalidModelConfigException, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) + +from .. import ModelLoader, ModelLoaderRegistry + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self.get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result + + # TO DO: Add exception handling + def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" + if submodel_type: + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + result = self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' + ) from e + else: + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config.get("_class_name", None) + if class_name: + result = self._hf_definition_to_type(module="diffusers", class_name=class_name) + if config.get("model_type", None) == "clip_vision_model": + class_name = config.get("architectures") + assert class_name is not None + result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) + if not class_name: + raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e + return result + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + +class ConfigLoader(ConfigMixin): + """Subclass of ConfigMixin for loading diffusers configuration files.""" + + @classmethod + def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Load a diffusrs ConfigMixin configuration.""" + cls.config_name = kwargs.pop("config_name") + # Diffusers doesn't provide typing info + return super().load_config(*args, **kwargs) # type: ignore diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py new file mode 100644 index 00000000000..7d25e9d218c --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for IP Adapter model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +import torch + +from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) +class IPAdapterInvokeAILoader(ModelLoader): + """Class to load IP Adapter diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in an IP-Adapter model.") + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path / "ip_adapter.bin", + device=torch.device("cpu"), + dtype=self._torch_dtype, + ) + return model diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py new file mode 100644 index 00000000000..fe804ef5654 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for LoRA model loading in InvokeAI.""" + + +from logging import Logger +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase + +from .. import ModelLoader, ModelLoaderRegistry + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) +class LoraLoader(ModelLoader): + """Class to load LoRA models.""" + + # We cheat a little bit to get access to the model base + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + super().__init__(app_config, logger, ram_cache, convert_cache) + self._model_base: Optional[BaseModelType] = None + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a LoRA model.") + assert self._model_base is not None + model = LoRAModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + base_model=self._model_base, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + self._model_base = ( + config.base + ) # cheating a little - we remember this variable for using in the subsequent call to _load_model() + + model_base_path = self._app_config.models_path + model_path = model_base_path / config.path + + if config.format == ModelFormat.Diffusers: + for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder + path = model_base_path / config.path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + result = model_path.resolve(), config, submodel_type + return result diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py new file mode 100644 index 00000000000..38f0274acc6 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for Onnx model loading in InvokeAI.""" + +# This should work the same as Stable Diffusion pipelines +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) + +from .. import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) +class OnnyxDiffusersModel(GenericDiffusersLoader): + """Class to load onnx models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading onnx pipelines.") + load_class = self.get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py new file mode 100644 index 00000000000..9952c883622 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for StableDiffusion model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional + +from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline + +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + ModelVariantType, + SubModelType, +) +from invokeai.backend.model_manager.config import MainCheckpointConfig +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers + +from .. import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) +class StableDiffusionDiffusersModel(GenericDiffusersLoader): + """Class to load main models.""" + + model_base_to_model_type = { + BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", + BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", + BaseModelType.StableDiffusionXL: "SDXL", + BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", + } + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading main pipelines.") + load_class = self._get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + assert isinstance(config, MainCheckpointConfig) + variant = config.variant + base = config.base + pipeline_class = ( + StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline + ) + + config_file = config.config + + self._logger.info(f"Converting {model_path} to diffusers format") + convert_ckpt_to_diffusers( + model_path, + output_path, + model_type=self.model_base_to_model_type[base], + model_version=base, + model_variant=variant, + original_config_file=self._app_config.root_path / config_file, + extract_ema=True, + scan_needed=True, + pipeline_class=pipeline_class, + from_safetensors=model_path.suffix == ".safetensors", + precision=self._torch_dtype, + load_safety_checker=False, + ) + return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py new file mode 100644 index 00000000000..094d4d7c5c3 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for TI model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.textual_inversion import TextualInversionModelRaw + +from .. import ModelLoader, ModelLoaderRegistry + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) +@ModelLoaderRegistry.register( + base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder +) +class TextualInversionLoader(ModelLoader): + """Class to load TI models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a TI model.") + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + model_path = self._app_config.models_path / config.path + + if config.format == ModelFormat.EmbeddingFolder: + path = model_path / "learned_embeds.bin" + else: + path = model_path + + if not path.exists(): + raise OSError(f"The embedding file at {path} was not found") + + return path, config, submodel_type diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py new file mode 100644 index 00000000000..7ade1494eb1 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path + +import safetensors +import torch +from omegaconf import DictConfig, OmegaConf + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers + +from .. import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +class VaeLoader(GenericDiffusersLoader): + """Class to load VAE models.""" + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + # TO DO: check whether sdxl VAE models convert. + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"Vae conversion not supported for model type: {config.base}") + else: + config_file = ( + "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + ) + + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") + else: + checkpoint = torch.load(model_path, map_location="cpu") + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) + assert isinstance(ckpt_config, DictConfig) + + vae_model = convert_ldm_vae_to_diffusers( + checkpoint=checkpoint, + vae_config=ckpt_config, + image_size=512, + ) + vae_model.to(self._torch_dtype) # set precision appropriately + vae_model.save_pretrained(output_path, safe_serialization=True) + return output_path diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py new file mode 100644 index 00000000000..c55eee48fa5 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_util.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024 The InvokeAI Development Team +"""Various utility functions needed by the loader and caching system.""" + +import json +from pathlib import Path +from typing import Optional + +import torch +from diffusers import DiffusionPipeline + +from invokeai.backend.model_manager.config import AnyModel +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + + +def calc_model_size_by_data(model: AnyModel) -> int: + """Get size of a model in memory in bytes.""" + if isinstance(model, DiffusionPipeline): + return _calc_pipeline_by_data(model) + elif isinstance(model, torch.nn.Module): + return _calc_model_by_data(model) + elif isinstance(model, IAIOnnxRuntimeModel): + return _calc_onnx_model_by_data(model) + else: + return 0 + + +def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int: + res = 0 + assert hasattr(pipeline, "components") + for submodel_key in pipeline.components.keys(): + submodel = getattr(pipeline, submodel_key) + if submodel is not None and isinstance(submodel, torch.nn.Module): + res += _calc_model_by_data(submodel) + return res + + +def _calc_model_by_data(model: torch.nn.Module) -> int: + mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) + mem: int = mem_params + mem_bufs # in bytes + return mem + + +def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int: + tensor_size = model.tensors.size() * 2 # The session doubles this + mem = tensor_size # in bytes + return mem + + +def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int: + """Estimate the size of a model on disk in bytes.""" + if model_path.is_file(): + return model_path.stat().st_size + + if subfolder is not None: + model_path = model_path / subfolder + + # this can happen when, for example, the safety checker is not downloaded. + if not model_path.exists(): + return 0 + + all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()] + + fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name} + bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name} + other_files = set(all_files) - fp16_files - bit8_files + + if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF + files = other_files + elif variant == "fp16": + files = fp16_files + elif variant == "8bit": + files = bit8_files + else: + raise NotImplementedError(f"Unknown variant: {variant}") + + # try read from index if exists + index_postfix = ".index.json" + if variant is not None: + index_postfix = f".index.{variant}.json" + + for file in files: + if not file.name.endswith(index_postfix): + continue + try: + with open(model_path / file, "r") as f: + index_data = json.loads(f.read()) + return int(index_data["metadata"]["total_size"]) + except Exception: + pass + + # calculate files size if there is no index file + formats = [ + (".safetensors",), # safetensors + (".bin",), # torch + (".onnx", ".pb"), # onnx + (".msgpack",), # flax + (".ckpt",), # tf + (".h5",), # tf2 + ] + + for file_format in formats: + model_files = [f for f in files if f.suffix in file_format] + if len(model_files) == 0: + continue + + model_size = 0 + for model_file in model_files: + file_stats = (model_path / model_file).stat() + model_size += file_stats.st_size + return model_size + + return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_manager/load/optimizations.py similarity index 64% rename from invokeai/backend/model_management/model_load_optimizations.py rename to invokeai/backend/model_manager/load/optimizations.py index a46d262175f..030fcfa639a 100644 --- a/invokeai/backend/model_management/model_load_optimizations.py +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -1,16 +1,16 @@ from contextlib import contextmanager +from typing import Any, Generator import torch -def _no_op(*args, **kwargs): +def _no_op(*args: Any, **kwargs: Any) -> None: pass @contextmanager -def skip_torch_weight_init(): - """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) - to skip weight initialization. +def skip_torch_weight_init() -> Generator[None, None, None]: + """Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization. By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is @@ -18,13 +18,14 @@ def skip_torch_weight_init(): monkey-patches common torch layers to skip the weight initialization step. """ torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] - saved_functions = [m.reset_parameters for m in torch_modules] + saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules] try: for torch_module in torch_modules: + assert hasattr(torch_module, "reset_parameters") torch_module.reset_parameters = _no_op - yield None finally: for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): + assert hasattr(torch_module, "reset_parameters") torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 2c94af4af3b..7063cb907d2 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -13,7 +13,7 @@ import torch from diffusers import AutoPipelineForText2Image -from diffusers import logging as dlogging +from diffusers.utils import logging as dlogging from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -39,10 +39,7 @@ class ModelMerger(object): def __init__(self, installer: ModelInstallServiceBase): """ - Initialize a ModelMerger object. - - :param store: Underlying storage manager for the running process. - :param config: InvokeAIAppConfig object (if not provided, default will be selected). + Initialize a ModelMerger object with the model installer. """ self._installer = installer @@ -79,7 +76,7 @@ def merge_diffusion_models( custom_pipeline="checkpoint_merger", torch_dtype=dtype, variant=variant, - ) + ) # type: ignore merged_pipe = pipe.merge( pretrained_model_name_or_path_list=model_paths, alpha=alpha, diff --git a/invokeai/backend/model_manager/metadata/__init__.py b/invokeai/backend/model_manager/metadata/__init__.py index 672e378c7fe..a35e55f3d24 100644 --- a/invokeai/backend/model_manager/metadata/__init__.py +++ b/invokeai/backend/model_manager/metadata/__init__.py @@ -18,7 +18,7 @@ if data.allow_commercial_use: print("Commercial use of this model is allowed") """ -from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch +from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase from .metadata_base import ( AnyModelRepoMetadata, AnyModelRepoMetadataValidator, @@ -31,7 +31,6 @@ RemoteModelFile, UnknownMetadataException, ) -from .metadata_store import ModelMetadataStore __all__ = [ "AnyModelRepoMetadata", @@ -42,7 +41,7 @@ "HuggingFaceMetadata", "HuggingFaceMetadataFetch", "LicenseRestrictions", - "ModelMetadataStore", + "ModelMetadataFetchBase", "BaseMetadata", "ModelMetadataWithFiles", "RemoteModelFile", diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py index 6e41d6f11b2..7991f6a7489 100644 --- a/invokeai/backend/model_manager/metadata/fetch/civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py @@ -32,6 +32,8 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, CivitaiMetadata, @@ -82,10 +84,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: return self.from_civitai_versionid(int(version_id)) raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns") - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given a Civitai model version ID, return a ModelRepoMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum (currently ignored) + May raise an `UnknownMetadataException`. """ return self.from_civitai_versionid(int(id)) diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index 58b65b69477..5d75493b92f 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -18,7 +18,9 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session -from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator +from invokeai.backend.model_manager import ModelRepoVariant + +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata class ModelMetadataFetchBase(ABC): @@ -45,10 +47,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: pass @abstractmethod - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given an ID for a model, return a ModelMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum. + This method will raise a `UnknownMetadataException` in the event that the requested model's metadata is not found at the provided id. """ @@ -57,5 +62,5 @@ def from_id(self, id: str) -> AnyModelRepoMetadata: @classmethod def from_json(cls, json: str) -> AnyModelRepoMetadata: """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" - metadata = AnyModelRepoMetadataValidator.validate_json(json) + metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore return metadata diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 5d1eb0cc9e4..6f04e8713b2 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,10 +19,12 @@ import requests from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, HuggingFaceMetadata, @@ -53,12 +55,22 @@ def from_json(cls, json: str) -> HuggingFaceMetadata: metadata = HuggingFaceMetadata.model_validate_json(json) return metadata - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" - try: - model_info = HfApi().model_info(repo_id=id, files_metadata=True) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + # Little loop which tries fetching a revision corresponding to the selected variant. + # If not available, then set variant to None and get the default. + # If this too fails, raise exception. + model_info = None + while not model_info: + try: + model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant) + except RepositoryNotFoundError as excp: + raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + except RevisionNotFoundError: + if variant is None: + raise + else: + variant = None _, name = id.split("/") return HuggingFaceMetadata( @@ -70,7 +82,7 @@ def from_id(self, id: str) -> AnyModelRepoMetadata: tags=model_info.tags, files=[ RemoteModelFile( - url=hf_hub_url(id, x.rfilename), + url=hf_hub_url(id, x.rfilename, revision=variant), path=Path(name, x.rfilename), size=x.size, sha256=x.lfs.get("sha256") if x.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5aa883d26d0..6e410d82220 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel): AllowDifferentLicense: bool = Field( description="if true, derivatives of this model be redistributed under a different license", default=False ) - AllowCommercialUse: CommercialUsage = Field( - description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set + AllowCommercialUse: Optional[CommercialUsage] = Field( + description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None ) @@ -139,7 +139,10 @@ def credit_required(self) -> bool: @property def allow_commercial_use(self) -> bool: """Return True if commercial use is allowed.""" - return self.restrictions.AllowCommercialUse != CommercialUsage("None") + if self.restrictions.AllowCommercialUse is None: + return False + else: + return self.restrictions.AllowCommercialUse != CommercialUsage("None") @property def allow_derivatives(self) -> bool: @@ -184,7 +187,6 @@ def download_urls( [x.path for x in self.files], variant, subfolder ) # all files in the model prefix = f"{subfolder}/" if subfolder else "" - # the next step reads model_index.json to determine which subdirectories belong # to the model if Path(f"{prefix}model_index.json") in paths: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index cd048d2fe78..7de4289466d 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -7,9 +7,7 @@ import torch from picklescan.scanner import scan_file_path -from invokeai.backend.model_management.models.base import read_checkpoint_meta -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat -from invokeai.backend.model_management.util import lora_token_vector_length +import invokeai.backend.util.logging as logger from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -18,18 +16,24 @@ InvalidModelConfigException, ModelConfigFactory, ModelFormat, + ModelRepoVariant, ModelType, ModelVariantType, SchedulerPredictionType, ) from .hash import FastModelHash +from .util.model_util import lora_token_vector_length, read_checkpoint_meta CkptType = Dict[str, Any] LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", + ModelVariantType.Normal: { + SchedulerPredictionType.Epsilon: "v1-inference.yaml", + SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", + }, ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", + ModelVariantType.Depth: "v2-midas-inference.yaml", }, BaseModelType.StableDiffusion2: { ModelVariantType.Normal: { @@ -72,6 +76,10 @@ def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Get model scheduler prediction type.""" return None + def get_image_encoder_model_id(self) -> Optional[str]: + """Get image encoder (IP adapters only).""" + return None + class ModelProbe(object): PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { @@ -147,6 +155,7 @@ def probe( fields["base"] = fields.get("base") or probe.get_base_type() fields["variant"] = fields.get("variant") or probe.get_variant_type() fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() + fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() fields["name"] = fields.get("name") or cls.get_model_name(model_path) fields["description"] = ( fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" @@ -155,6 +164,9 @@ def probe( fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash + if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): + fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() + # additional fields needed for main and controlnet models if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: fields["config"] = cls._get_checkpoint_config_path( @@ -477,6 +489,21 @@ def get_variant_type(self) -> ModelVariantType: def get_format(self) -> ModelFormat: return ModelFormat("diffusers") + def get_repo_variant(self) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(self.model_path.glob("**/*.safetensors")) + weight_files.extend(list(self.model_path.glob("**/*.bin"))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OPENVINO + if "flax_model" in x.name: + return ModelRepoVariant.FLAX + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.DEFAULT + class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: @@ -567,13 +594,20 @@ def get_base_type(self) -> BaseModelType: return TextualInversionCheckpointProbe(path).get_base_type() -class ONNXFolderProbe(FolderProbeBase): +class ONNXFolderProbe(PipelineFolderProbe): + def get_base_type(self) -> BaseModelType: + # Due to the way the installer is set up, the configuration file for safetensors + # will come along for the ride if both the onnx and safetensors forms + # share the same directory. We take advantage of this here. + if (self.model_path / "unet" / "config.json").exists(): + return super().get_base_type() + else: + logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"') + return BaseModelType.StableDiffusion1 + def get_format(self) -> ModelFormat: return ModelFormat("onnx") - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal @@ -617,8 +651,8 @@ def get_base_type(self) -> BaseModelType: class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> IPAdapterModelFormat: - return IPAdapterModelFormat.InvokeAI.value + def get_format(self) -> ModelFormat: + return ModelFormat.InvokeAI def get_base_type(self) -> BaseModelType: model_file = self.model_path / "ip_adapter.bin" @@ -638,6 +672,14 @@ def get_base_type(self) -> BaseModelType: f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." ) + def get_image_encoder_model_id(self) -> Optional[str]: + encoder_id_path = self.model_path / "image_encoder.txt" + if not encoder_id_path.exists(): + return None + with open(encoder_id_path, "r") as f: + image_encoder_model = f.readline().strip() + return image_encoder_model + class CLIPVisionFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 4cc3caebe47..f7ef2e049d4 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -22,6 +22,7 @@ def find_main_models(model: Path) -> bool: import os from abc import ABC, abstractmethod +from logging import Logger from pathlib import Path from typing import Callable, Optional, Set, Union @@ -29,7 +30,7 @@ def find_main_models(model: Path) -> bool: from invokeai.backend.util.logging import InvokeAILogger -default_logger = InvokeAILogger.get_logger() +default_logger: Logger = InvokeAILogger.get_logger() class SearchStats(BaseModel): @@ -56,7 +57,7 @@ class ModelSearchBase(ABC, BaseModel): on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221 stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 - logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221 + logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221 # fmt: on class Config: @@ -115,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Set[Path] = Field(default=None) - scanned_dirs: Set[Path] = Field(default=None) - pruned_paths: Set[Path] = Field(default=None) + models_found: Set[Path] = Field(default_factory=set) + scanned_dirs: Set[Path] = Field(default_factory=set) + pruned_paths: Set[Path] = Field(default_factory=set) def search_started(self) -> None: self.models_found = set() @@ -128,13 +129,13 @@ def search_started(self) -> None: def model_found(self, model: Path) -> None: self.stats.models_found += 1 - if not self.on_model_found or self.on_model_found(model): + if self.on_model_found is None or self.on_model_found(model): self.stats.models_filtered += 1 self.models_found.add(model) def search_completed(self) -> None: - if self.on_search_completed: - self.on_search_completed(self._models_found) + if self.on_search_completed is not None: + self.on_search_completed(self.models_found) def search(self, directory: Union[Path, str]) -> Set[Path]: self._directory = Path(directory) diff --git a/invokeai/backend/model_manager/util/libc_util.py b/invokeai/backend/model_manager/util/libc_util.py new file mode 100644 index 00000000000..ef1ac2f8a4b --- /dev/null +++ b/invokeai/backend/model_manager/util/libc_util.py @@ -0,0 +1,76 @@ +import ctypes + + +class Struct_mallinfo2(ctypes.Structure): + """A ctypes Structure that matches the libc mallinfo2 struct. + + Docs: + - https://man7.org/linux/man-pages/man3/mallinfo.3.html + - https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html + + struct mallinfo2 { + size_t arena; /* Non-mmapped space allocated (bytes) */ + size_t ordblks; /* Number of free chunks */ + size_t smblks; /* Number of free fastbin blocks */ + size_t hblks; /* Number of mmapped regions */ + size_t hblkhd; /* Space allocated in mmapped regions (bytes) */ + size_t usmblks; /* See below */ + size_t fsmblks; /* Space in freed fastbin blocks (bytes) */ + size_t uordblks; /* Total allocated space (bytes) */ + size_t fordblks; /* Total free space (bytes) */ + size_t keepcost; /* Top-most, releasable space (bytes) */ + }; + """ + + _fields_ = [ + ("arena", ctypes.c_size_t), + ("ordblks", ctypes.c_size_t), + ("smblks", ctypes.c_size_t), + ("hblks", ctypes.c_size_t), + ("hblkhd", ctypes.c_size_t), + ("usmblks", ctypes.c_size_t), + ("fsmblks", ctypes.c_size_t), + ("uordblks", ctypes.c_size_t), + ("fordblks", ctypes.c_size_t), + ("keepcost", ctypes.c_size_t), + ] + + def __str__(self) -> str: + s = "" + s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n" + s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n" + s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n" + s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n" + s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n" + s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n" + s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n" + s += ( + f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)" + " (GB)\n" + ) + s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n" + s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n" + return s + + +class LibcUtil: + """A utility class for interacting with the C Standard Library (`libc`) via ctypes. + + Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where + this shared library is not available. + + TODO: Improve cross-OS compatibility of this class. + """ + + def __init__(self) -> None: + self._libc = ctypes.cdll.LoadLibrary("libc.so.6") + + def mallinfo2(self) -> Struct_mallinfo2: + """Calls `libc` `mallinfo2`. + + Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html + """ + mallinfo2 = self._libc.mallinfo2 + mallinfo2.restype = Struct_mallinfo2 + result: Struct_mallinfo2 = mallinfo2() + return result diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_manager/util/model_util.py similarity index 56% rename from invokeai/backend/model_management/util.py rename to invokeai/backend/model_manager/util/model_util.py index f4737d9f0b5..2e448520e56 100644 --- a/invokeai/backend/model_management/util.py +++ b/invokeai/backend/model_manager/util/model_util.py @@ -1,15 +1,71 @@ -# Copyright (c) 2023 The InvokeAI Development Team -"""Utilities used by the Model Manager""" - - -def lora_token_vector_length(checkpoint: dict) -> int: +"""Utilities for parsing model files, used mostly by probe.py""" + +import json +from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors +import torch +from picklescan.scanner import scan_file_path + + +def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]: + checkpoint = {} + device = torch.device("meta") + with open(path, "rb") as f: + definition_len = int.from_bytes(f.read(8), "little") + definition_json = f.read(definition_len) + definition = json.loads(definition_json) + + if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { + "pt", + "torch", + "pytorch", + }: + raise Exception("Supported only pytorch safetensors files") + definition.pop("__metadata__", None) + + for key, info in definition.items(): + dtype = { + "I8": torch.int8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + }[info["dtype"]] + + checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) + + return checkpoint + + +def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]: + if str(path).endswith(".safetensors"): + try: + path_str = path.as_posix() if isinstance(path, Path) else path + checkpoint = _fast_safetensors_reader(path_str) + except Exception: + # TODO: create issue for support "meta"? + checkpoint = safetensors.torch.load_file(path, device="cpu") + else: + if scan: + scan_result = scan_file_path(path) + if scan_result.infected_files != 0: + raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') + checkpoint = torch.load(path, map_location=torch.device("meta")) + return checkpoint + + +def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: """ Given a checkpoint in memory, return the lora token vector length :param checkpoint: The checkpoint """ - def _get_shape_1(key: str, tensor, checkpoint) -> int: + def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: lora_token_vector_length = None if "." not in key: diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 69760590440..2fd7a3721ab 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -36,23 +36,37 @@ def filter_files( """ variant = variant or ModelRepoVariant.DEFAULT paths: List[Path] = [] + root = files[0].parts[0] + + # if the subfolder is a single file, then bypass the selection and just return it + if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]: + return [root / subfolder] # Start by filtering on model file extensions, discarding images, docs, etc for file in files: if file.name.endswith((".json", ".txt")): paths.append(file) - elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")): + elif file.name.endswith( + ( + "learned_embeds.bin", + "ip_adapter.bin", + "lora_weights.safetensors", + "weights.pb", + "onnx_data", + ) + ): paths.append(file) # BRITTLENESS WARNING!! # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid # downloading random checkpoints that might also be in the repo. However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models - # will adhere to this naming convention, so this is an area of brittleness. + # will adhere to this naming convention, so this is an area to be careful of. elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) # limit search to subfolder if requested if subfolder: + subfolder = root / subfolder paths = [x for x in paths if x.parent == Path(subfolder)] # _filter_by_variant uniquifies the paths and returns a set @@ -64,7 +78,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path result = set() basenames: Dict[Path, Path] = {} for path in files: - if path.suffix == ".onnx": + if path.suffix in [".onnx", ".pb", ".onnx_data"]: if variant == ModelRepoVariant.ONNX: result.add(path) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_patcher.py similarity index 79% rename from invokeai/backend/model_management/lora.py rename to invokeai/backend/model_patcher.py index d72f55794d3..bee8909c311 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_patcher.py @@ -1,21 +1,23 @@ +# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team +"""These classes implement model patching with LoRAs and Textual Inversions.""" from __future__ import annotations import pickle from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple import numpy as np import torch -from compel.embeddings_provider import BaseTextualInversionManager -from diffusers.models import UNet2DConditionModel -from safetensors.torch import load_file +from diffusers import OnnxRuntimeModel, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init +from invokeai.backend.model_manager import AnyModel +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel -from .models.lora import LoRAModel +from .lora import LoRAModelRaw +from .textual_inversion import TextualInversionManager, TextualInversionModelRaw """ loras = [ @@ -62,8 +64,8 @@ def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tup def apply_lora_unet( cls, unet: UNet2DConditionModel, - loras: List[Tuple[LoRAModel, float]], - ): + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -72,8 +74,8 @@ def apply_lora_unet( def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -82,8 +84,8 @@ def apply_lora_text_encoder( def apply_sdxl_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): + loras: List[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te1_"): yield @@ -92,8 +94,8 @@ def apply_sdxl_lora_text_encoder( def apply_sdxl_lora_text_encoder2( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): + loras: List[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te2_"): yield @@ -101,10 +103,10 @@ def apply_sdxl_lora_text_encoder2( @contextmanager def apply_lora( cls, - model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], + model: AnyModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str, - ): + ) -> None: original_weights = {} try: with torch.no_grad(): @@ -121,6 +123,7 @@ def apply_lora( # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA # weights to have valid keys. + assert isinstance(model, torch.nn.Module) module_key, module = cls._resolve_lora_key(model, layer_key, prefix) # All of the LoRA weight calculations will be done on the same device as the module weight. @@ -141,17 +144,21 @@ def apply_lora( # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) - layer.to(device="cpu") + layer.to(device=torch.device("cpu")) + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris + assert hasattr(layer_weight, "reshape") layer_weight = layer_weight.reshape(module.weight.shape) + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! module.weight += layer_weight.to(dtype=dtype) yield # wait for context manager exit finally: + assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() with torch.no_grad(): for module_key, weight in original_weights.items(): model.get_submodule(module_key).weight.copy_(weight) @@ -162,8 +169,8 @@ def apply_ti( cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - ti_list: List[Tuple[str, Any]], - ) -> Tuple[CLIPTokenizer, TextualInversionManager]: + ti_list: List[Tuple[str, TextualInversionModelRaw]], + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: init_tokens_count = None new_tokens_added = None @@ -187,13 +194,13 @@ def apply_ti( ti_manager = TextualInversionManager(ti_tokenizer) init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings - def _get_trigger(ti_name, index): + def _get_trigger(ti_name: str, index: int) -> str: trigger = ti_name if index > 0: trigger += f"-!pad-{i}" return f"<{trigger}>" - def _get_ti_embedding(model_embeddings, ti): + def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor: # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -221,6 +228,7 @@ def _get_ti_embedding(model_embeddings, ti): model_embeddings = text_encoder.get_input_embeddings() for ti_name, ti in ti_list: + assert isinstance(ti, TextualInversionModelRaw) ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) ti_tokens = [] @@ -259,7 +267,7 @@ def apply_clip_skip( cls, text_encoder: CLIPTextModel, clip_skip: int, - ): + ) -> None: skipped_layers = [] try: for _i in range(clip_skip): @@ -277,9 +285,10 @@ def apply_freeu( cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ): + ) -> None: did_apply_freeu = False try: + assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? if freeu_config is not None: unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) did_apply_freeu = True @@ -287,109 +296,19 @@ def apply_freeu( yield finally: + assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? if did_apply_freeu: unet.disable_freeu() -class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] - embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if not isinstance(file_path, Path): - file_path = Path(file_path) - - result = cls() # TODO: - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - # both v1 and v2 format embeddings - # difference mostly in metadata - if "string_to_param" in state_dict: - if len(state_dict["string_to_param"]) > 1: - print( - f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', - " token will be used.", - ) - - result.embedding = next(iter(state_dict["string_to_param"].values())) - - # v3 (easynegative) - elif "emb_params" in state_dict: - result.embedding = state_dict["emb_params"] - - # v5(sdxl safetensors file) - elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] - - # v4(diffusers bin files) - else: - result.embedding = next(iter(state_dict.values())) - - if len(result.embedding.shape) == 1: - result.embedding = result.embedding.unsqueeze(0) - - if not isinstance(result.embedding, torch.Tensor): - raise ValueError(f"Invalid embeddings file: {file_path.name}") - - return result - - -class TextualInversionManager(BaseTextualInversionManager): - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer - - def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} - self.tokenizer = tokenizer - - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: - if len(self.pad_tokens) == 0: - return token_ids - - if token_ids[0] == self.tokenizer.bos_token_id: - raise ValueError("token_ids must not start with bos_token_id") - if token_ids[-1] == self.tokenizer.eos_token_id: - raise ValueError("token_ids must not end with eos_token_id") - - new_token_ids = [] - for token_id in token_ids: - new_token_ids.append(token_id) - if token_id in self.pad_tokens: - new_token_ids.extend(self.pad_tokens[token_id]) - - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. - max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 - if len(new_token_ids) > max_length: - new_token_ids = new_token_ids[0:max_length] - - return new_token_ids - - class ONNXModelPatcher: - from diffusers import OnnxRuntimeModel - - from .models.base import IAIOnnxRuntimeModel - @classmethod @contextmanager def apply_lora_unet( cls, unet: OnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - ): + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -398,8 +317,8 @@ def apply_lora_unet( def apply_lora_text_encoder( cls, text_encoder: OnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - ): + loras: List[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -410,9 +329,9 @@ def apply_lora_text_encoder( def apply_lora( cls, model: IAIOnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], prefix: str, - ): + ) -> None: from .models.base import IAIOnnxRuntimeModel if not isinstance(model, IAIOnnxRuntimeModel): @@ -421,7 +340,7 @@ def apply_lora( orig_weights = {} try: - blended_loras = {} + blended_loras: Dict[str, torch.Tensor] = {} for lora, lora_weight in loras: for layer_key, layer in lora.layers.items(): @@ -432,7 +351,7 @@ def apply_lora( layer_key = layer_key.replace(prefix, "") # TODO: rewrite to pass original tensor weight(required by ia3) layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight - if layer_key is blended_loras: + if layer_key in blended_loras: blended_loras[layer_key] += layer_weight else: blended_loras[layer_key] = layer_weight @@ -499,7 +418,7 @@ def apply_ti( tokenizer: CLIPTokenizer, text_encoder: IAIOnnxRuntimeModel, ti_list: List[Tuple[str, Any]], - ) -> Tuple[CLIPTokenizer, TextualInversionManager]: + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: from .models.base import IAIOnnxRuntimeModel if not isinstance(text_encoder, IAIOnnxRuntimeModel): @@ -517,7 +436,7 @@ def apply_ti( ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) ti_manager = TextualInversionManager(ti_tokenizer) - def _get_trigger(ti_name, index): + def _get_trigger(ti_name: str, index: int) -> str: trigger = ti_name if index > 0: trigger += f"-!pad-{i}" diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py new file mode 100644 index 00000000000..8916865dd52 --- /dev/null +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -0,0 +1,218 @@ +# Copyright (c) 2024 The InvokeAI Development Team +import os +import sys +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import onnx +from onnx import numpy_helper +from onnxruntime import InferenceSession, SessionOptions, get_available_providers + +from ..raw_model import RawModel + +ONNX_WEIGHTS_NAME = "model.onnx" + + +# NOTE FROM LS: This was copied from Stalker's original implementation. +# I have not yet gone through and fixed all the type hints +class IAIOnnxRuntimeModel(RawModel): + class _tensor_access: + def __init__(self, model): # type: ignore + self.model = model + self.indexes = {} + for idx, obj in enumerate(self.model.proto.graph.initializer): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + value = self.model.proto.graph.initializer[self.indexes[key]] + return numpy_helper.to_array(value) + + def __setitem__(self, key: str, value: np.ndarray): # type: ignore + new_node = numpy_helper.from_array(value) + # set_external_data(new_node, location="in-memory-location") + new_node.name = key + # new_node.ClearField("raw_data") + del self.model.proto.graph.initializer[self.indexes[key]] + self.model.proto.graph.initializer.insert(self.indexes[key], new_node) + # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return self.indexes[key] in self.model.proto.graph.initializer + + def items(self) -> List[Tuple[str, Any]]: # fixme + raise NotImplementedError("tensor.items") + # return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + raise NotImplementedError("tensor.values") + # return [obj for obj in self.raw_proto] + + def size(self) -> int: + bytesSum = 0 + for node in self.model.proto.graph.initializer: + bytesSum += sys.getsizeof(node.raw_data) + return bytesSum + + class _access_helper: + def __init__(self, raw_proto): # type: ignore + self.indexes = {} + self.raw_proto = raw_proto + for idx, obj in enumerate(raw_proto): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + return self.raw_proto[self.indexes[key]] + + def __setitem__(self, key: str, value): # type: ignore + index = self.indexes[key] + del self.raw_proto[index] + self.raw_proto.insert(index, value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return key in self.indexes + + def items(self) -> List[Tuple[str, Any]]: + return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + return list(self.raw_proto) + + def __init__(self, model_path: str, provider: Optional[str]): + self.path = model_path + self.session = None + self.provider = provider + """ + self.data_path = self.path + "_data" + if not os.path.exists(self.data_path): + print(f"Moving model tensors to separate file: {self.data_path}") + tmp_proto = onnx.load(model_path, load_external_data=True) + onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) + del tmp_proto + gc.collect() + + self.proto = onnx.load(model_path, load_external_data=False) + """ + + self.proto = onnx.load(model_path, load_external_data=True) + # self.data = dict() + # for tensor in self.proto.graph.initializer: + # name = tensor.name + + # if tensor.HasField("raw_data"): + # npt = numpy_helper.to_array(tensor) + # orv = OrtValue.ortvalue_from_numpy(npt) + # # self.data[name] = orv + # # set_external_data(tensor, location="in-memory-location") + # tensor.name = name + # # tensor.ClearField("raw_data") + + self.nodes = self._access_helper(self.proto.graph.node) # type: ignore + # self.initializers = self._access_helper(self.proto.graph.initializer) + # print(self.proto.graph.input) + # print(self.proto.graph.initializer) + + self.tensors = self._tensor_access(self) # type: ignore + + # TODO: integrate with model manager/cache + def create_session(self, height=None, width=None): + if self.session is None or self.session_width != width or self.session_height != height: + # onnx.save(self.proto, "tmp.onnx") + # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) + # TODO: something to be able to get weight when they already moved outside of model proto + # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) + sess = SessionOptions() + # self._external_data.update(**external_data) + # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) + # sess.enable_profiling = True + + # sess.intra_op_num_threads = 1 + # sess.inter_op_num_threads = 1 + # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL + # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + # sess.enable_cpu_mem_arena = True + # sess.enable_mem_pattern = True + # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code + self.session_height = height + self.session_width = width + if height and width: + sess.add_free_dimension_override_by_name("unet_sample_batch", 2) + sess.add_free_dimension_override_by_name("unet_sample_channels", 4) + sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) + sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) + sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) + sess.add_free_dimension_override_by_name("unet_time_batch", 1) + providers = [] + if self.provider: + providers.append(self.provider) + else: + providers = get_available_providers() + if "TensorrtExecutionProvider" in providers: + providers.remove("TensorrtExecutionProvider") + try: + self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) + except Exception as e: + raise e + # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) + # self.io_binding = self.session.io_binding() + + def release_session(self): + self.session = None + import gc + + gc.collect() + return + + def __call__(self, **kwargs): + if self.session is None: + raise Exception("You should call create_session before running model") + + inputs = {k: np.array(v) for k, v in kwargs.items()} + # output_names = self.session.get_outputs() + # for k in inputs: + # self.io_binding.bind_cpu_input(k, inputs[k]) + # for name in output_names: + # self.io_binding.bind_output(name.name) + # self.session.run_with_iobinding(self.io_binding, None) + # return self.io_binding.copy_outputs_to_cpu() + return self.session.run(None, inputs) + + # compatability with diffusers load code + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + subfolder: Optional[Union[str, Path]] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["SessionOptions"] = None, + **kwargs: Any, + ) -> Any: # fixme + file_name = file_name or ONNX_WEIGHTS_NAME + + if os.path.isdir(model_id): + model_path = model_id + if subfolder is not None: + model_path = os.path.join(model_path, subfolder) + model_path = os.path.join(model_path, file_name) + + else: + model_path = model_id + + # load model from local directory + if not os.path.isfile(model_path): + raise Exception(f"Model not found: {model_path}") + + # TODO: session options + return cls(str(model_path), provider=provider) diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py new file mode 100644 index 00000000000..d0dc50c4560 --- /dev/null +++ b/invokeai/backend/raw_model.py @@ -0,0 +1,15 @@ +"""Base class for 'Raw' models. + +The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. Its main purpose +is to avoid a circular import issues when lora.py tries to import BaseModelType +from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw +from lora.py. + +The term 'raw' was introduced to describe a wrapper around a torch.nn.Module +that adds additional methods and attributes. +""" + + +class RawModel: + """Base class for 'Raw' model wrappers.""" diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index 212045f81b8..75e6aa0a5de 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -4,3 +4,12 @@ from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401 from .diffusion import InvokeAIDiffuserComponent # noqa: F401 from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401 +from .seamless import set_seamless # noqa: F401 + +__all__ = [ + "PipelineIntermediateState", + "StableDiffusionGeneratorPipeline", + "InvokeAIDiffuserComponent", + "AttentionMapSaver", + "set_seamless", +] diff --git a/invokeai/backend/stable_diffusion/schedulers/__init__.py b/invokeai/backend/stable_diffusion/schedulers/__init__.py index a4e9dbf9dad..0b780d3ee27 100644 --- a/invokeai/backend/stable_diffusion/schedulers/__init__.py +++ b/invokeai/backend/stable_diffusion/schedulers/__init__.py @@ -1 +1,3 @@ from .schedulers import SCHEDULER_MAP # noqa: F401 + +__all__ = ["SCHEDULER_MAP"] diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py new file mode 100644 index 00000000000..fb9112b56dc --- /dev/null +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import Callable, List, Union + +import torch.nn as nn +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + + +def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) + return nn.functional.conv2d( + working, + weight, + bias, + self.stride, + nn.modules.utils._pair(0), + self.dilation, + self.groups, + ) + + +@contextmanager +def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor + to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = [] + try: + # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence + skipped_layers = 1 + for m_name, m in model.named_modules(): + if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + continue + + if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name: + # down_blocks.1.resnets.1.conv1 + _, block_num, _, resnet_num, submodule_name = m_name.split(".") + block_num = int(block_num) + resnet_num = int(resnet_num) + + if block_num >= len(model.down_blocks) - skipped_layers: + continue + + # Skip the second resnet (could be configurable) + if resnet_num > 0: + continue + + # Skip Conv2d layers (could be configurable) + if submodule_name == "conv2": + continue + + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(module, "asymmetric_padding_mode"): + del module.asymmetric_padding_mode + if hasattr(module, "asymmetric_padding"): + del module.asymmetric_padding diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py new file mode 100644 index 00000000000..f7390979bbc --- /dev/null +++ b/invokeai/backend/textual_inversion.py @@ -0,0 +1,100 @@ +"""Textual Inversion wrapper class.""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from safetensors.torch import load_file +from transformers import CLIPTokenizer +from typing_extensions import Self + +from .raw_model import RawModel + + +class TextualInversionModelRaw(RawModel): + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', + " token will be used.", + ) + + result.embedding = next(iter(state_dict["string_to_param"].values())) + + # v3 (easynegative) + elif "emb_params" in state_dict: + result.embedding = state_dict["emb_params"] + + # v5(sdxl safetensors file) + elif "clip_g" in state_dict and "clip_l" in state_dict: + result.embedding = state_dict["clip_g"] + result.embedding_2 = state_dict["clip_l"] + + # v4(diffusers bin files) + else: + result.embedding = next(iter(state_dict.values())) + + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + + if not isinstance(result.embedding, torch.Tensor): + raise ValueError(f"Invalid embeddings file: {file_path.name}") + + return result + + +# no type hints for BaseTextualInversionManager? +class TextualInversionManager(BaseTextualInversionManager): # type: ignore + pad_tokens: Dict[int, List[int]] + tokenizer: CLIPTokenizer + + def __init__(self, tokenizer: CLIPTokenizer): + self.pad_tokens = {} + self.tokenizer = tokenizer + + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + if len(self.pad_tokens) == 0: + return token_ids + + if token_ids[0] == self.tokenizer.bos_token_id: + raise ValueError("token_ids must not start with bos_token_id") + if token_ids[-1] == self.tokenizer.eos_token_id: + raise ValueError("token_ids must not end with eos_token_id") + + new_token_ids = [] + for token_id in token_ids: + new_token_ids.append(token_id) + if token_id in self.pad_tokens: + new_token_ids.extend(self.pad_tokens[token_id]) + + # Do not exceed the max model input size + # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), + # which first removes and then adds back the start and end tokens. + max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 + if len(new_token_ids) > max_length: + new_token_ids = new_token_ids[0:max_length] + + return new_token_ids diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 87ae1480f54..7b48f0364ea 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -12,6 +12,22 @@ torch_dtype, ) from .logging import InvokeAILogger -from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 +from .util import ( # TO DO: Clean this up; remove the unused symbols + GIG, + Chdir, + ask_user, # noqa + directory_size, + download_with_resume, + instantiate_from_config, # noqa + url_attachment_name, # noqa +) -__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"] +__all__ = [ + "GIG", + "directory_size", + "Chdir", + "download_with_resume", + "InvokeAILogger", + "choose_precision", + "choose_torch_device", +] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index d6d3ad727f7..a83d1045f70 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Union +from typing import Literal, Optional, Union import torch from torch import autocast @@ -29,12 +29,19 @@ def choose_torch_device() -> torch.device: return torch.device(config.device) -def choose_precision(device: torch.device) -> str: - """Returns an appropriate precision for the given torch device""" +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def choose_precision( + device: torch.device, app_config: Optional[InvokeAIAppConfig] = None +) -> Literal["float32", "float16", "bfloat16"]: + """Return an appropriate precision for the given torch device.""" + app_config = app_config or config if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): - if config.precision == "bfloat16": + if app_config.precision == "float32": + return "float32" + elif app_config.precision == "bfloat16": return "bfloat16" else: return "float16" @@ -43,8 +50,14 @@ def choose_precision(device: torch.device) -> str: return "float32" -def torch_dtype(device: torch.device) -> torch.dtype: - precision = choose_precision(device) +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def torch_dtype( + device: Optional[torch.device] = None, + app_config: Optional[InvokeAIAppConfig] = None, +) -> torch.dtype: + device = device or choose_torch_device() + precision = choose_precision(device, app_config) if precision == "float16": return torch.float16 if precision == "bfloat16": diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py new file mode 100644 index 00000000000..068b605da97 --- /dev/null +++ b/invokeai/backend/util/silence_warnings.py @@ -0,0 +1,28 @@ +"""Context class to silence transformers and diffusers warnings.""" +import warnings +from typing import Any + +from diffusers import logging as diffusers_logging +from transformers import logging as transformers_logging + + +class SilenceWarnings(object): + """Use in context to temporarily turn off warnings from transformers & diffusers modules. + + with SilenceWarnings(): + # do something + """ + + def __init__(self) -> None: + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self) -> None: + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, *args: Any) -> None: + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 09b9de9e984..72bfc3c6a76 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -5,10 +5,9 @@ import pytest import torch -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import ModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.model_records import UnknownModelException +from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType @pytest.fixture(scope="session") @@ -16,31 +15,20 @@ def torch_device(): return "cuda" if torch.cuda.is_available() else "cpu" -@pytest.fixture(scope="module") -def model_installer(): - """A global ModelInstall pytest fixture to be used by many tests.""" - # HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions - # between tests that need to alter the config. For example, some tests change the 'root' directory in the config, - # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround, - # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using - # a singleton. - return ModelInstall(InvokeAIAppConfig.get_config(log_level="info")) - - def install_and_load_model( - model_installer: ModelInstall, + model_manager: ModelManagerServiceBase, model_path_id_or_url: Union[str, Path], model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> ModelInfo: - """Install a model if it is not already installed, then get the ModelInfo for that model. +) -> LoadedModel: + """Install a model if it is not already installed, then get the LoadedModel for that model. This is intended as a utility function for tests. Args: - model_installer (ModelInstall): The model installer. + mm2_model_manager (ModelManagerServiceBase): The model manager model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it is not already installed. model_name (str): The model name, forwarded to ModelManager.get_model(...). @@ -51,16 +39,23 @@ def install_and_load_model( Returns: ModelInfo """ - # If the requested model is already installed, return its ModelInfo. - with contextlib.suppress(ModelNotFoundException): - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) + # If the requested model is already installed, return its LoadedModel + with contextlib.suppress(UnknownModelException): + # TODO: Replace with wrapper call + loaded_model: LoadedModel = model_manager.load.load_model_by_attr( + name=model_name, base=base_model, type=model_type + ) + return loaded_model # Install the requested model. - model_installer.heuristic_import(model_path_id_or_url) + job = model_manager.install.heuristic_import(model_path_id_or_url) + model_manager.install.wait_for_job(job, timeout=10) + assert job.is_complete try: - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) - except ModelNotFoundException as e: + loaded_model = model_manager.load.load_by_config(job.config) + return loaded_model + except UnknownModelException as e: raise Exception( "Failed to get model info after installing it. There could be a mismatch between the requested model and" f" the installation id ('{model_path_id_or_url}'). Error: {e}" diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 13751e27702..ae376b41b25 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -24,6 +24,22 @@ from .devices import torch_dtype +# actual size of a gig +GIG = 1073741824 + + +def directory_size(directory: Path) -> int: + """ + Return the aggregate size of all files in a directory (bytes). + """ + sum = 0 + for root, dirs, files in os.walk(directory): + for f in files: + sum += Path(root, f).stat().st_size + for d in dirs: + sum += Path(root, d).stat().st_size + return sum + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index c230665e3a6..ca2283ab811 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -1,153 +1,157 @@ # This file predefines a few models that the user may want to install. sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 + source: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting + source: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 + source: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting + source: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 + source: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 + source: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-vae-fp16-fix: + description: Version of the SDXL-1.0 VAE that works in half precision mode + source: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion + source: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate_v5: +sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors + source: XpucT/Deliberate recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion + source: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 + source: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion + source: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney + source: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA + source: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 + source: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster + source: monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - repo_id: lllyasviel/control_v11p_sd15_canny + source: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint + source: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd + source: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth + source: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae + source: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg + source: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart + source: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime + source: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - repo_id: lllyasviel/control_v11p_sd15_openpose + source: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble + source: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge + source: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle + source: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile + source: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p + source: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 + source: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 + source: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - repo_id: TencentARC/t2iadapter_depth_sd15v2 + source: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 + source: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 + source: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 + source: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 + source: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d + description: A textual inversion to use in the negative prompt to reduce bad anatomy +sd-1/lora/FlatColor: + source: https://civitai.com/models/6433/loraflatcolor + recommended: True + description: A LoRA that generates scenery using solid blocks of color sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 + source: https://civitai.com/api/download/models/83390 + description: Generate india ink-like landscapes sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 + source: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 + source: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 + source: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl + source: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder + source: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder + source: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/INITIAL_MODELS2.yaml b/invokeai/configs/INITIAL_MODELS2.yaml deleted file mode 100644 index ca2283ab811..00000000000 --- a/invokeai/configs/INITIAL_MODELS2.yaml +++ /dev/null @@ -1,157 +0,0 @@ -# This file predefines a few models that the user may want to install. -sd-1/main/stable-diffusion-v1-5: - description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - source: runwayml/stable-diffusion-v1-5 - recommended: True - default: True -sd-1/main/stable-diffusion-v1-5-inpainting: - description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - source: runwayml/stable-diffusion-inpainting - recommended: True -sd-2/main/stable-diffusion-2-1: - description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - source: stabilityai/stable-diffusion-2-1 - recommended: False -sd-2/main/stable-diffusion-2-inpainting: - description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - source: stabilityai/stable-diffusion-2-inpainting - recommended: False -sdxl/main/stable-diffusion-xl-base-1-0: - description: Stable Diffusion XL base model (12 GB) - source: stabilityai/stable-diffusion-xl-base-1.0 - recommended: True -sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: - description: Stable Diffusion XL refiner model (12 GB) - source: stabilityai/stable-diffusion-xl-refiner-1.0 - recommended: False -sdxl/vae/sdxl-vae-fp16-fix: - description: Version of the SDXL-1.0 VAE that works in half precision mode - source: madebyollin/sdxl-vae-fp16-fix - recommended: True -sd-1/main/Analog-Diffusion: - description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - source: wavymulder/Analog-Diffusion - recommended: False -sd-1/main/Deliberate: - description: Versatile model that produces detailed images up to 768px (4.27 GB) - source: XpucT/Deliberate - recommended: False -sd-1/main/Dungeons-and-Diffusion: - description: Dungeons & Dragons characters (2.13 GB) - source: 0xJustin/Dungeons-and-Diffusion - recommended: False -sd-1/main/dreamlike-photoreal-2: - description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - source: dreamlike-art/dreamlike-photoreal-2.0 - recommended: False -sd-1/main/Inkpunk-Diffusion: - description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - source: Envvi/Inkpunk-Diffusion - recommended: False -sd-1/main/openjourney: - description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - source: prompthero/openjourney - recommended: False -sd-1/main/seek.art_MEGA: - source: coreco/seek.art_MEGA - description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) - recommended: False -sd-1/main/trinart_stable_diffusion_v2: - description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - source: naclbit/trinart_stable_diffusion_v2 - recommended: False -sd-1/controlnet/qrcode_monster: - source: monster-labs/control_v1p_sd15_qrcode_monster - subfolder: v2 -sd-1/controlnet/canny: - source: lllyasviel/control_v11p_sd15_canny - recommended: True -sd-1/controlnet/inpaint: - source: lllyasviel/control_v11p_sd15_inpaint -sd-1/controlnet/mlsd: - source: lllyasviel/control_v11p_sd15_mlsd -sd-1/controlnet/depth: - source: lllyasviel/control_v11f1p_sd15_depth - recommended: True -sd-1/controlnet/normal_bae: - source: lllyasviel/control_v11p_sd15_normalbae -sd-1/controlnet/seg: - source: lllyasviel/control_v11p_sd15_seg -sd-1/controlnet/lineart: - source: lllyasviel/control_v11p_sd15_lineart - recommended: True -sd-1/controlnet/lineart_anime: - source: lllyasviel/control_v11p_sd15s2_lineart_anime -sd-1/controlnet/openpose: - source: lllyasviel/control_v11p_sd15_openpose - recommended: True -sd-1/controlnet/scribble: - source: lllyasviel/control_v11p_sd15_scribble - recommended: False -sd-1/controlnet/softedge: - source: lllyasviel/control_v11p_sd15_softedge -sd-1/controlnet/shuffle: - source: lllyasviel/control_v11e_sd15_shuffle -sd-1/controlnet/tile: - source: lllyasviel/control_v11f1e_sd15_tile -sd-1/controlnet/ip2p: - source: lllyasviel/control_v11e_sd15_ip2p -sd-1/t2i_adapter/canny-sd15: - source: TencentARC/t2iadapter_canny_sd15v2 -sd-1/t2i_adapter/sketch-sd15: - source: TencentARC/t2iadapter_sketch_sd15v2 -sd-1/t2i_adapter/depth-sd15: - source: TencentARC/t2iadapter_depth_sd15v2 -sd-1/t2i_adapter/zoedepth-sd15: - source: TencentARC/t2iadapter_zoedepth_sd15v1 -sdxl/t2i_adapter/canny-sdxl: - source: TencentARC/t2i-adapter-canny-sdxl-1.0 -sdxl/t2i_adapter/zoedepth-sdxl: - source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 -sdxl/t2i_adapter/lineart-sdxl: - source: TencentARC/t2i-adapter-lineart-sdxl-1.0 -sdxl/t2i_adapter/sketch-sdxl: - source: TencentARC/t2i-adapter-sketch-sdxl-1.0 -sd-1/embedding/EasyNegative: - source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors - recommended: True - description: A textual inversion to use in the negative prompt to reduce bad anatomy -sd-1/lora/FlatColor: - source: https://civitai.com/models/6433/loraflatcolor - recommended: True - description: A LoRA that generates scenery using solid blocks of color -sd-1/lora/Ink scenery: - source: https://civitai.com/api/download/models/83390 - description: Generate india ink-like landscapes -sd-1/ip_adapter/ip_adapter_sd15: - source: InvokeAI/ip_adapter_sd15 - recommended: True - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: IP-Adapter for SD 1.5 models -sd-1/ip_adapter/ip_adapter_plus_sd15: - source: InvokeAI/ip_adapter_plus_sd15 - recommended: False - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: Refined IP-Adapter for SD 1.5 models -sd-1/ip_adapter/ip_adapter_plus_face_sd15: - source: InvokeAI/ip_adapter_plus_face_sd15 - recommended: False - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: Refined IP-Adapter for SD 1.5 models, adapted for faces -sdxl/ip_adapter/ip_adapter_sdxl: - source: InvokeAI/ip_adapter_sdxl - recommended: False - requires: - - InvokeAI/ip_adapter_sdxl_image_encoder - description: IP-Adapter for SDXL models -any/clip_vision/ip_adapter_sd_image_encoder: - source: InvokeAI/ip_adapter_sd_image_encoder - recommended: False - description: Required model for using IP-Adapters with SD-1/2 models -any/clip_vision/ip_adapter_sdxl_image_encoder: - source: InvokeAI/ip_adapter_sdxl_image_encoder - recommended: False - description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/models.yaml.example b/invokeai/configs/models.yaml.example deleted file mode 100644 index 98f8f77e62c..00000000000 --- a/invokeai/configs/models.yaml.example +++ /dev/null @@ -1,47 +0,0 @@ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. -diffusers-1.4: - description: 🤗🧨 Stable Diffusion v1.4 - format: diffusers - repo_id: CompVis/stable-diffusion-v1-4 -diffusers-1.5: - description: 🤗🧨 Stable Diffusion v1.5 - format: diffusers - repo_id: runwayml/stable-diffusion-v1-5 - default: true -diffusers-1.5+mse: - description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE - format: diffusers - repo_id: runwayml/stable-diffusion-v1-5 - vae: - repo_id: stabilityai/sd-vae-ft-mse -diffusers-inpainting-1.5: - description: 🤗🧨 inpainting for Stable Diffusion v1.5 - format: diffusers - repo_id: runwayml/stable-diffusion-inpainting -stable-diffusion-1.5: - description: The newest Stable Diffusion version 1.5 weight file (4.27 GB) - weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt - config: configs/stable-diffusion/v1-inference.yaml - width: 512 - height: 512 - vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt -stable-diffusion-1.4: - description: Stable Diffusion inference model version 1.4 - config: configs/stable-diffusion/v1-inference.yaml - weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt - vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt - width: 512 - height: 512 -inpainting-1.5: - weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt - config: configs/stable-diffusion/v1-inpainting-inference.yaml - vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt - description: RunwayML SD 1.5 model optimized for inpainting - width: 512 - height: 512 diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index e23538ffd66..20b630dfc62 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -6,47 +6,45 @@ """ This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. +It is currently named model_install2.py, but will ultimately replace model_install.py. """ import argparse import curses -import logging import sys -import textwrap import traceback +import warnings from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe -from pathlib import Path from shutil import get_terminal_size -from typing import Optional +from typing import Any, Dict, List, Optional, Set import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo +from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, - BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, - select_stable_diffusion_config_file, set_min_terminal_size, ) +warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger() +logger = InvokeAILogger.get_logger("ModelInstallService") +logger.setLevel("WARNING") +# logger.setLevel('DEBUG') # build a table mapping all non-printable characters to None # for stripping control characters @@ -58,44 +56,42 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" + """Replace non-printable characters in a string.""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): + """Main form for interactive TUI.""" + # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp, name, multipage=False, *args, **keywords): + def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? + super().__init__(parentApp=parentApp, name=name, **keywords) - def create(self): + def create(self) -> None: + self.installer = self.parentApp.install_helper.installer + self.model_labels = self._get_model_labels() self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - self.nextrely -= 1 + # npyscreen has no typing hints + self.nextrely -= 1 # type: ignore self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 + self.nextrely += 1 # type: ignore self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -115,9 +111,9 @@ def create(self): ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely + top_of_table = self.nextrely # type: ignore self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely + bottom_of_table = self.nextrely # type: ignore self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -162,15 +158,7 @@ def create(self): self.nextrely = bottom_of_table + 1 - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - self.nextrely += 1 - done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -186,14 +174,8 @@ def create(self): npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - label = "APPLY CHANGES & EXIT" + label = "APPLY CHANGES" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -210,17 +192,16 @@ def create(self): ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets = {} - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels + widgets: Dict[str, npyscreen.widget] = {} - self.installed_models = sorted([x for x in starters if models[x].installed]) + all_models = self.all_models # master dict of all models, indexed by key + model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] + model_labels = [self.model_labels[x] for x in model_list] widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", + name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", editable=False, labelColor="CAUTION", ) @@ -230,23 +211,24 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] + + checked = [ + model_list.index(x) + for x in model_list + if (show_recommended and all_models[x].recommended) or all_models[x].installed + ] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, + values=model_labels, + value=checked, + max_height=len(model_list) + 1, relx=4, scroll_exit=True, ), - models=keys, + models=model_list, ) self.nextrely += 1 @@ -257,14 +239,18 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: str = None, - exclude: set = None, + install_prompt: Optional[str] = None, + exclude: Optional[Set[str]] = None, ) -> dict[str, npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets = {} - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] + widgets: Dict[str, npyscreen.widget] = {} + all_models = self.all_models + model_list = sorted( + [x for x in all_models if all_models[x].type == model_type and x not in exclude], + key=lambda x: all_models[x].name or "", + ) model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -300,7 +286,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed + if (show_recommended and all_models[x].recommended) or all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -324,7 +310,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, + max_height=6, scroll_exit=True, editable=True, ) @@ -349,13 +335,13 @@ def add_pipeline_widgets( return widgets - def resize(self): + def resize(self) -> None: super().resize() if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] + if model_list := self.starter_pipelines.get("models"): + s.values = [self.model_labels[x] for x in model_list] - def _toggle_tables(self, value=None): + def _toggle_tables(self, value: List[int]) -> None: selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -385,17 +371,18 @@ def _toggle_tables(self, value=None): self.display() def _get_model_labels(self) -> dict[str, str]: + """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 + result = {} models = self.all_models - label_width = max([len(models[x].name) for x in models]) + label_width = max([len(models[x].name or "") for x in self.starter_models]) description_width = window_width - label_width - checkbox_width - spacing_width - result = {} - for x in models.keys(): - description = models[x].description + for key in self.all_models: + description = models[key].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -403,7 +390,8 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) + result[key] = f"%-{label_width}s %s" % (models[key].name, description) + return result def _get_columns(self) -> int: @@ -413,50 +401,40 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel( + if remove_models: + model_names = [self.all_models[x].name or "" for x in remove_models] + mods = "\n".join(model_names) + is_ok = npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}" ) + assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations + return is_ok else: return True - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return + @property + def all_models(self) -> Dict[str, UnifiedModelInfo]: + # npyscreen doesn't having typing hints + return self.parentApp.install_helper.all_models # type: ignore - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() + @property + def starter_models(self) -> List[str]: + return self.parentApp.install_helper._starter_models # type: ignore - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs={ - "opt": app.program_opts, - "selections": app.install_selections, - "conn_out": child_conn, - }, - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() + @property + def installed_models(self) -> List[str]: + return self.parentApp.install_helper._installed_models # type: ignore - def on_back(self): + def on_back(self) -> None: self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self): + def on_cancel(self) -> None: self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self): + def on_done(self) -> None: self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -464,77 +442,7 @@ def on_done(self): self.parentApp.user_cancelled = False self.editing = False - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. - saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - - def marshall_arguments(self): + def marshall_arguments(self) -> None: """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -564,46 +472,24 @@ def marshall_arguments(self): models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) + selections.install_models.extend([all_models[x] for x in models_to_install]) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) - - # NOT NEEDED - DONE IN BACKEND NOW - # # special case for the ipadapter_models. If any of the adapters are - # # chosen, then we add the corresponding encoder(s) to the install list. - # section = self.ipadapter_models - # if section.get("models_selected"): - # selected_adapters = [ - # self.all_models[section["models"][x]].name for x in section.get("models_selected").value - # ] - # encoders = [] - # if any(["sdxl" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sdxl_image_encoder") - # if any(["sd15" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sd_image_encoder") - # for encoder in encoders: - # key = f"any/clip_vision/{encoder}" - # repo_id = f"InvokeAI/{encoder}" - # if key not in self.all_models: - # selections.install_models.append(repo_id) - - -class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): + models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] + selections.install_models.extend(models) + + +class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore + def __init__(self, opt: Namespace, install_helper: InstallHelper): super().__init__() self.program_opts = opt self.user_cancelled = False - # self.autoload_pending = True self.install_selections = InstallSelections() + self.install_helper = install_helper - def onStart(self): + def onStart(self) -> None: npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -613,138 +499,62 @@ def onStart(self): ) -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): + """Print out all models of type model_type.""" + models = installer.record_store.search_by_attr(model_type=model_type) + print(f"Installed models of type `{model_type}`:") + for model in models: + path = (config.models_path / model.path).resolve() + print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the scheduler prediction type of the checkpoint named {model_path.name}: -[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images -[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models -[3] Accept the best guess; you can fix it in the Web UI later -""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select [3]> ").strip() - if not choice: - return None - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "guess": - return None - else: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace): +def select_and_download_models(opt: Namespace) -> None: + """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) + # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + install_helper = InstallHelper(config, logger) + installer = install_helper.installer + if opt.list_models: - installer.list_models(opt.list_models) + list_models(installer, opt.list_models) + elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) + selections = InstallSelections( + install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] + ) + install_helper.add_or_delete(selections) + elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) + default_model = install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) + install_helper.add_or_delete(selections) + elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) + selections = InstallSelections(install_models=install_helper.recommended_models()) + install_helper.add_or_delete(selections) # this is where the TUI is called else: - # needed to support the probe() method running under a subprocess - torch.multiprocessing.set_start_method("spawn") - if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - installApp = AddModelApplication(opt) + installApp = AddModelApplication(opt, install_helper) try: installApp.run() - except KeyboardInterrupt as e: - if hasattr(installApp, "main_form"): - if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): - logger.info("Terminating subprocesses") - installApp.main_form.subprocess.terminate() - installApp.main_form.subprocess = None - raise e - process_and_execute(opt, installApp.install_selections) + except KeyboardInterrupt: + print("Aborted...") + sys.exit(-1) + + install_helper.add_or_delete(installApp.install_selections) # ------------------------------------- -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -754,7 +564,7 @@ def main(): parser.add_argument( "--delete", nargs="*", - help="List of names of models to idelete", + help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", ) parser.add_argument( "--full-precision", @@ -781,14 +591,6 @@ def main(): choices=[x.value for x in ModelType], help="list installed models", ) - parser.add_argument( - "--config_file", - "-c", - dest="config_file", - type=str, - default=None, - help="path to configuration file to create", - ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py deleted file mode 100644 index 6eb480c8d9d..00000000000 --- a/invokeai/frontend/install/model_install2.py +++ /dev/null @@ -1,645 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) -# Before running stable-diffusion on an internet-isolated machine, -# run this script from one with internet connectivity. The -# two machines must share a common .cache directory. - -""" -This is the npyscreen frontend to the model installation application. -It is currently named model_install2.py, but will ultimately replace model_install.py. -""" - -import argparse -import curses -import sys -import traceback -import warnings -from argparse import Namespace -from shutil import get_terminal_size -from typing import Any, Dict, List, Optional, Set - -import npyscreen -import torch -from npyscreen import widget - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallService -from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo -from invokeai.backend.model_manager import ModelType -from invokeai.backend.util import choose_precision, choose_torch_device -from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.widgets import ( - MIN_COLS, - MIN_LINES, - CenteredTitleText, - CyclingForm, - MultiSelectColumns, - SingleSelectColumns, - TextBox, - WindowTooSmallException, - set_min_terminal_size, -) - -warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger("ModelInstallService") -logger.setLevel("WARNING") -# logger.setLevel('DEBUG') - -# build a table mapping all non-printable characters to None -# for stripping control characters -# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python -NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()} - -# maximum number of installed models we can display before overflowing vertically -MAX_OTHER_MODELS = 72 - - -def make_printable(s: str) -> str: - """Replace non-printable characters in a string.""" - return s.translate(NOPRINT_TRANS_TABLE) - - -class addModelsForm(CyclingForm, npyscreen.FormMultiPage): - """Main form for interactive TUI.""" - - # for responsive resizing set to False, but this seems to cause a crash! - FIX_MINIMUM_SIZE_WHEN_CREATED = True - - # for persistence - current_tab = 0 - - def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): - self.multipage = multipage - self.subprocess = None - super().__init__(parentApp=parentApp, name=name, **keywords) - - def create(self) -> None: - self.installer = self.parentApp.install_helper.installer - self.model_labels = self._get_model_labels() - self.keypress_timeout = 10 - self.counter = 0 - self.subprocess_connection = None - - window_width, window_height = get_terminal_size() - - # npyscreen has no typing hints - self.nextrely -= 1 # type: ignore - self.add_widget_intelligent( - npyscreen.FixedText, - value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", - editable=False, - color="CAUTION", - ) - self.nextrely += 1 # type: ignore - self.tabs = self.add_widget_intelligent( - SingleSelectColumns, - values=[ - "STARTERS", - "MAINS", - "CONTROLNETS", - "T2I-ADAPTERS", - "IP-ADAPTERS", - "LORAS", - "TI EMBEDDINGS", - ], - value=[self.current_tab], - columns=7, - max_height=2, - relx=8, - scroll_exit=True, - ) - self.tabs.on_changed = self._toggle_tables - - top_of_table = self.nextrely # type: ignore - self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely # type: ignore - - self.nextrely = top_of_table - self.pipeline_models = self.add_pipeline_widgets( - model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models - ) - # self.pipeline_models['autoload_pending'] = True - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.controlnet_models = self.add_model_widgets( - model_type=ModelType.ControlNet, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.t2i_models = self.add_model_widgets( - model_type=ModelType.T2IAdapter, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - self.nextrely = top_of_table - self.ipadapter_models = self.add_model_widgets( - model_type=ModelType.IPAdapter, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.lora_models = self.add_model_widgets( - model_type=ModelType.Lora, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.ti_models = self.add_model_widgets( - model_type=ModelType.TextualInversion, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = bottom_of_table + 1 - - self.nextrely += 1 - back_label = "BACK" - cancel_label = "CANCEL" - current_position = self.nextrely - if self.multipage: - self.back_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=back_label, - when_pressed_function=self.on_back, - ) - else: - self.nextrely = current_position - self.cancel_button = self.add_widget_intelligent( - npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel - ) - self.nextrely = current_position - - label = "APPLY CHANGES" - self.nextrely = current_position - self.done = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=label, - relx=window_width - len(label) - 15, - when_pressed_function=self.on_done, - ) - - # This restores the selected page on return from an installation - for _i in range(1, self.current_tab + 1): - self.tabs.h_cursor_line_down(1) - self._toggle_tables([self.current_tab]) - - ############# diffusers tab ########## - def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: - """Add widgets responsible for selecting diffusers models""" - widgets: Dict[str, npyscreen.widget] = {} - - all_models = self.all_models # master dict of all models, indexed by key - model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] - model_labels = [self.model_labels[x] for x in model_list] - - widgets.update( - label1=self.add_widget_intelligent( - CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", - editable=False, - labelColor="CAUTION", - ) - ) - - self.nextrely -= 1 - # if user has already installed some initial models, then don't patronize them - # by showing more recommendations - show_recommended = len(self.installed_models) == 0 - - checked = [ - model_list.index(x) - for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed - ] - widgets.update( - models_selected=self.add_widget_intelligent( - MultiSelectColumns, - columns=1, - name="Install Starter Models", - values=model_labels, - value=checked, - max_height=len(model_list) + 1, - relx=4, - scroll_exit=True, - ), - models=model_list, - ) - - self.nextrely += 1 - return widgets - - ############# Add a set of model install widgets ######## - def add_model_widgets( - self, - model_type: ModelType, - window_width: int = 120, - install_prompt: Optional[str] = None, - exclude: Optional[Set[str]] = None, - ) -> dict[str, npyscreen.widget]: - """Generic code to create model selection widgets""" - if exclude is None: - exclude = set() - widgets: Dict[str, npyscreen.widget] = {} - all_models = self.all_models - model_list = sorted( - [x for x in all_models if all_models[x].type == model_type and x not in exclude], - key=lambda x: all_models[x].name or "", - ) - model_labels = [self.model_labels[x] for x in model_list] - - show_recommended = len(self.installed_models) == 0 - truncated = False - if len(model_list) > 0: - max_width = max([len(x) for x in model_labels]) - columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding - columns = min(len(model_list), columns) or 1 - prompt = ( - install_prompt - or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk." - ) - - widgets.update( - label1=self.add_widget_intelligent( - CenteredTitleText, - name=prompt, - editable=False, - labelColor="CAUTION", - ) - ) - - if len(model_labels) > MAX_OTHER_MODELS: - model_labels = model_labels[0:MAX_OTHER_MODELS] - truncated = True - - widgets.update( - models_selected=self.add_widget_intelligent( - MultiSelectColumns, - columns=columns, - name=f"Install {model_type} Models", - values=model_labels, - value=[ - model_list.index(x) - for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed - ], - max_height=len(model_list) // columns + 1, - relx=4, - scroll_exit=True, - ), - models=model_list, - ) - - if truncated: - widgets.update( - warning_message=self.add_widget_intelligent( - npyscreen.FixedText, - value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.", - editable=False, - color="CAUTION", - ) - ) - - self.nextrely += 1 - widgets.update( - download_ids=self.add_widget_intelligent( - TextBox, - name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=6, - scroll_exit=True, - editable=True, - ) - ) - return widgets - - ### Tab for arbitrary diffusers widgets ### - def add_pipeline_widgets( - self, - model_type: ModelType = ModelType.Main, - window_width: int = 120, - **kwargs, - ) -> dict[str, npyscreen.widget]: - """Similar to add_model_widgets() but adds some additional widgets at the bottom - to support the autoload directory""" - widgets = self.add_model_widgets( - model_type=model_type, - window_width=window_width, - install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.", - **kwargs, - ) - - return widgets - - def resize(self) -> None: - super().resize() - if s := self.starter_pipelines.get("models_selected"): - if model_list := self.starter_pipelines.get("models"): - s.values = [self.model_labels[x] for x in model_list] - - def _toggle_tables(self, value: List[int]) -> None: - selected_tab = value[0] - widgets = [ - self.starter_pipelines, - self.pipeline_models, - self.controlnet_models, - self.t2i_models, - self.ipadapter_models, - self.lora_models, - self.ti_models, - ] - - for group in widgets: - for _k, v in group.items(): - try: - v.hidden = True - v.editable = False - except Exception: - pass - for _k, v in widgets[selected_tab].items(): - try: - v.hidden = False - if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): - v.editable = True - except Exception: - pass - self.__class__.current_tab = selected_tab # for persistence - self.display() - - def _get_model_labels(self) -> dict[str, str]: - """Return a list of trimmed labels for all models.""" - window_width, window_height = get_terminal_size() - checkbox_width = 4 - spacing_width = 2 - result = {} - - models = self.all_models - label_width = max([len(models[x].name or "") for x in self.starter_models]) - description_width = window_width - label_width - checkbox_width - spacing_width - - for key in self.all_models: - description = models[key].description - description = ( - description[0 : description_width - 3] + "..." - if description and len(description) > description_width - else description - if description - else "" - ) - result[key] = f"%-{label_width}s %s" % (models[key].name, description) - - return result - - def _get_columns(self) -> int: - window_width, window_height = get_terminal_size() - cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1 - return min(cols, len(self.installed_models)) - - def confirm_deletions(self, selections: InstallSelections) -> bool: - remove_models = selections.remove_models - if remove_models: - model_names = [self.all_models[x].name or "" for x in remove_models] - mods = "\n".join(model_names) - is_ok = npyscreen.notify_ok_cancel( - f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}" - ) - assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations - return is_ok - else: - return True - - @property - def all_models(self) -> Dict[str, UnifiedModelInfo]: - # npyscreen doesn't having typing hints - return self.parentApp.install_helper.all_models # type: ignore - - @property - def starter_models(self) -> List[str]: - return self.parentApp.install_helper._starter_models # type: ignore - - @property - def installed_models(self) -> List[str]: - return self.parentApp.install_helper._installed_models # type: ignore - - def on_back(self) -> None: - self.parentApp.switchFormPrevious() - self.editing = False - - def on_cancel(self) -> None: - self.parentApp.setNextForm(None) - self.parentApp.user_cancelled = True - self.editing = False - - def on_done(self) -> None: - self.marshall_arguments() - if not self.confirm_deletions(self.parentApp.install_selections): - return - self.parentApp.setNextForm(None) - self.parentApp.user_cancelled = False - self.editing = False - - def marshall_arguments(self) -> None: - """ - Assemble arguments and store as attributes of the application: - .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml - True => Install - False => Remove - .scan_directory: Path to a directory of models to scan and import - .autoscan_on_startup: True if invokeai should scan and import at startup time - .import_model_paths: list of URLs, repo_ids and file paths to import - """ - selections = self.parentApp.install_selections - all_models = self.all_models - - # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove - ui_sections = [ - self.starter_pipelines, - self.pipeline_models, - self.controlnet_models, - self.t2i_models, - self.ipadapter_models, - self.lora_models, - self.ti_models, - ] - for section in ui_sections: - if "models_selected" not in section: - continue - selected = {section["models"][x] for x in section["models_selected"].value} - models_to_install = [x for x in selected if not self.all_models[x].installed] - models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] - selections.remove_models.extend(models_to_remove) - selections.install_models.extend([all_models[x] for x in models_to_install]) - - # models located in the 'download_ids" section - for section in ui_sections: - if downloads := section.get("download_ids"): - models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] - selections.install_models.extend(models) - - -class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, opt: Namespace, install_helper: InstallHelper): - super().__init__() - self.program_opts = opt - self.user_cancelled = False - self.install_selections = InstallSelections() - self.install_helper = install_helper - - def onStart(self) -> None: - npyscreen.setTheme(npyscreen.Themes.DefaultTheme) - self.main_form = self.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - cycle_widgets=False, - ) - - -def list_models(installer: ModelInstallService, model_type: ModelType): - """Print out all models of type model_type.""" - models = installer.record_store.search_by_attr(model_type=model_type) - print(f"Installed models of type `{model_type}`:") - for model in models: - path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:14}{path}") - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace) -> None: - """Prompt user for install/delete selections and execute.""" - precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal - config.precision = precision # type: ignore - install_helper = InstallHelper(config, logger) - installer = install_helper.installer - - if opt.list_models: - list_models(installer, opt.list_models) - - elif opt.add or opt.delete: - selections = InstallSelections( - install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] - ) - install_helper.add_or_delete(selections) - - elif opt.default_only: - selections = InstallSelections(install_models=[install_helper.default_model()]) - install_helper.add_or_delete(selections) - - elif opt.yes_to_all: - selections = InstallSelections(install_models=install_helper.recommended_models()) - install_helper.add_or_delete(selections) - - # this is where the TUI is called - else: - if not set_min_terminal_size(MIN_COLS, MIN_LINES): - raise WindowTooSmallException( - "Could not increase terminal size. Try running again with a larger window or smaller font size." - ) - - installApp = AddModelApplication(opt, install_helper) - try: - installApp.run() - except KeyboardInterrupt: - print("Aborted...") - sys.exit(-1) - - install_helper.add_or_delete(installApp.install_selections) - - -# ------------------------------------- -def main() -> None: - parser = argparse.ArgumentParser(description="InvokeAI model downloader") - parser.add_argument( - "--add", - nargs="*", - help="List of URLs, local paths or repo_ids of models to install", - ) - parser.add_argument( - "--delete", - nargs="*", - help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", - ) - parser.add_argument( - "--full-precision", - dest="full_precision", - action=argparse.BooleanOptionalAction, - type=bool, - default=False, - help="use 32-bit weights instead of faster 16-bit weights", - ) - parser.add_argument( - "--yes", - "-y", - dest="yes_to_all", - action="store_true", - help='answer "yes" to all prompts', - ) - parser.add_argument( - "--default_only", - action="store_true", - help="Only install the default model", - ) - parser.add_argument( - "--list-models", - choices=[x.value for x in ModelType], - help="list installed models", - ) - parser.add_argument( - "--root_dir", - dest="root", - type=str, - default=None, - help="path to root of install directory", - ) - opt = parser.parse_args() - - invoke_args = [] - if opt.root: - invoke_args.extend(["--root", opt.root]) - if opt.full_precision: - invoke_args.extend(["--precision", "float32"]) - config.parse_args(invoke_args) - logger = InvokeAILogger().get_logger(config=config) - - if not config.model_conf_path.exists(): - logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.") - from invokeai.frontend.install.invokeai_configure import invokeai_configure - - invokeai_configure() - sys.exit(0) - - try: - select_and_download_models(opt) - except AssertionError as e: - logger.error(e) - sys.exit(-1) - except KeyboardInterrupt: - curses.nocbreak() - curses.echo() - curses.endwin() - logger.info("Goodbye! Come back soon.") - except WindowTooSmallException as e: - logger.error(str(e)) - except widget.NotEnoughSpaceForWidget as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("Insufficient vertical space for the interface. Please make your window taller and try again") - input("Press any key to continue...") - except Exception as e: - if str(e).startswith("addwstr"): - logger.error( - "Insufficient horizontal space for the interface. Please make your window wider and try again." - ) - else: - print(f"An exception has occurred: {str(e)} Details:") - print(traceback.format_exc(), file=sys.stderr) - input("Press any key to continue...") - - -# ------------------------------------- -if __name__ == "__main__": - main() diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 5905ae29dab..4dbc6349a0b 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -267,6 +267,17 @@ def h_select(self, ch): self.on_changed(self.value) +class CheckboxWithChanged(npyscreen.Checkbox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.on_changed = None + + def whenToggled(self): + super().whenToggled() + if self.on_changed: + self.on_changed(self.value) + + class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged): """Row of radio buttons. Spacebar to select.""" diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 92b98b52f96..5484040674d 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,20 +6,40 @@ """ import argparse import curses +import re import sys from argparse import Namespace from pathlib import Path -from typing import List +from typing import List, Optional, Tuple import npyscreen from npyscreen import widget -import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType +from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage +from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.backend.model_manager import ( + BaseModelType, + ModelFormat, + ModelType, + ModelVariantType, +) +from invokeai.backend.model_manager.merge import ModelMerger +from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = InvokeAIAppConfig.get_config() +logger = InvokeAILogger.get_logger() + +BASE_TYPES = [ + (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), + (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), + (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), +] def _parse_args() -> Namespace: @@ -48,7 +68,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--base_model", type=str, - choices=[x.value for x in BaseModelType], + choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( @@ -98,17 +118,17 @@ def __init__(self, parentApp, name): super().__init__(parentApp, name) @property - def model_manager(self): - return self.parentApp.model_manager + def record_store(self): + return self.parentApp.record_store def afterEditing(self): self.parentApp.setNextForm(None) def create(self): window_height, window_width = curses.initscr().getmaxyx() - - self.model_names = self.get_model_names() self.current_base = 0 + self.models = self.get_models(BASE_TYPES[self.current_base][0]) + self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -128,11 +148,7 @@ def create(self): self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, - values=[ - "Models Built on SD-1.x", - "Models Built on SD-2.x", - "Models Built on SDXL", - ], + values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, @@ -263,21 +279,20 @@ def on_cancel(self): sys.exit(0) def marshall_arguments(self) -> dict: - model_names = self.model_names + model_keys = [x[0] for x in self.models] models = [ - model_names[self.model1.value[0]], - model_names[self.model2.value[0]], + model_keys[self.model1.value[0]], + model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: - models.append(model_names[self.model3.value[0] - 1]) + models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] - bases = ["sd-1", "sd-2", "sdxl"] args = { - "model_names": models, - "base_model": BaseModelType(bases[self.base_select.value[0]]), + "model_keys": models, + "base_model": tuple(BaseModelType)[self.base_select.value[0]], "alpha": self.alpha.value, "interp": interp, "force": self.force.value, @@ -311,18 +326,18 @@ def validate_field_values(self) -> bool: else: return True - def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]: - model_names = [ - info["model_name"] - for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) - if info["model_format"] == "diffusers" + def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name + models = [ + (x.key, x.name) + for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) + if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal") ] - return sorted(model_names) + return sorted(models, key=lambda x: x[1]) - def _populate_models(self, value=None): - bases = ["sd-1", "sd-2", "sdxl"] - base_model = BaseModelType(bases[value[0]]) - self.model_names = self.get_model_names(base_model) + def _populate_models(self, value: List[int]): + base_model = BASE_TYPES[value[0]][0] + self.models = self.get_models(base_model) + self.model_names = [x[1] for x in self.models] models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -334,24 +349,24 @@ def _populate_models(self, value=None): class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self, model_manager: ModelManager): + def __init__(self, record_store: ModelRecordServiceBase): super().__init__() - self.model_manager = model_manager + self.record_store = record_store def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") -def run_gui(args: Namespace): - model_manager = ModelManager(config.model_conf_path) - mergeapp = Mergeapp(model_manager) +def run_gui(args: Namespace) -> None: + record_store: ModelRecordServiceBase = get_config_store() + mergeapp = Mergeapp(record_store) mergeapp.run() - args = mergeapp.merge_arguments - merger = ModelMerger(model_manager) + merger = get_model_merger(record_store) merger.merge_diffusion_models_and_save(**args) - logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') + merged_model_name = args["merged_model_name"] + logger.info(f'Models merged into new model: "{merged_model_name}".') def run_cli(args: Namespace): @@ -364,20 +379,54 @@ def run_cli(args: Namespace): args.merged_model_name = "+".join(args.model_names) logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - model_manager = ModelManager(config.model_conf_path) + record_store: ModelRecordServiceBase = get_config_store() assert ( - not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - merger = ModelMerger(model_manager) - merger.merge_diffusion_models_and_save(**vars(args)) + merger = get_model_merger(record_store) + model_keys = [] + for name in args.model_names: + if len(name) == 32 and re.match(r"^[0-9a-f]$", name): + model_keys.append(name) + else: + models = record_store.search_by_attr( + model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) + ) + assert len(models) > 0, f"{name}: Unknown model" + assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." + model_keys.append(models[0].key) + + merger.merge_diffusion_models_and_save( + alpha=args.alpha, + model_keys=model_keys, + merged_model_name=args.merged_model_name, + interp=args.interp, + force=args.force, + ) logger.info(f'Models merged into new model: "{args.merged_model_name}".') +def get_config_store() -> ModelRecordServiceSQL: + output_path = config.output_path + assert output_path is not None + image_files = DiskImageFileStorage(output_path / "images") + db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + +def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger: + installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService()) + installer.start() + return ModelMerger(installer) + + def main(): args = _parse_args() if args.root_dir: config.parse_args(["--root", str(args.root_dir)]) + else: + config.parse_args([]) try: if args.front_end: diff --git a/invokeai/frontend/merge/merge_diffusers2.py b/invokeai/frontend/merge/merge_diffusers2.py deleted file mode 100644 index b365198f879..00000000000 --- a/invokeai/frontend/merge/merge_diffusers2.py +++ /dev/null @@ -1,438 +0,0 @@ -""" -invokeai.frontend.merge exports a single function called merge_diffusion_models(). - -It merges 2-3 models together and create a new InvokeAI-registered diffusion model. - -Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team -""" -import argparse -import curses -import re -import sys -from argparse import Namespace -from pathlib import Path -from typing import List, Optional, Tuple - -import npyscreen -from npyscreen import widget - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.install.install_helper import initialize_installer -from invokeai.backend.model_manager import ( - BaseModelType, - ModelFormat, - ModelType, - ModelVariantType, -) -from invokeai.backend.model_manager.merge import ModelMerger -from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox - -config = InvokeAIAppConfig.get_config() - -BASE_TYPES = [ - (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), - (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), - (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), -] - - -def _parse_args() -> Namespace: - parser = argparse.ArgumentParser(description="InvokeAI model merging") - parser.add_argument( - "--root_dir", - type=Path, - default=config.root, - help="Path to the invokeai runtime directory", - ) - parser.add_argument( - "--front_end", - "--gui", - dest="front_end", - action="store_true", - default=False, - help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.", - ) - parser.add_argument( - "--models", - dest="model_names", - type=str, - nargs="+", - help="Two to three model names to be merged", - ) - parser.add_argument( - "--base_model", - type=str, - choices=[x[0].value for x in BASE_TYPES], - help="The base model shared by the models to be merged", - ) - parser.add_argument( - "--merged_model_name", - "--destination", - dest="merged_model_name", - type=str, - help="Name of the output model. If not specified, will be the concatenation of the input model names.", - ) - parser.add_argument( - "--alpha", - type=float, - default=0.5, - help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models", - ) - parser.add_argument( - "--interpolation", - dest="interp", - type=str, - choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"], - default="weighted_sum", - help='Interpolation method to use. If three models are present, only "add_difference" will work.', - ) - parser.add_argument( - "--force", - action="store_true", - help="Try to merge models even if they are incompatible with each other", - ) - parser.add_argument( - "--clobber", - "--overwrite", - dest="clobber", - action="store_true", - help="Overwrite the merged model if --merged_model_name already exists", - ) - return parser.parse_args() - - -# ------------------------- GUI HERE ------------------------- -class mergeModelsForm(npyscreen.FormMultiPageAction): - interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"] - - def __init__(self, parentApp, name): - self.parentApp = parentApp - self.ALLOW_RESIZE = True - self.FIX_MINIMUM_SIZE_WHEN_CREATED = False - super().__init__(parentApp, name) - - @property - def model_record_store(self) -> ModelRecordServiceBase: - installer: ModelInstallServiceBase = self.parentApp.installer - return installer.record_store - - def afterEditing(self) -> None: - self.parentApp.setNextForm(None) - - def create(self) -> None: - window_height, window_width = curses.initscr().getmaxyx() - self.current_base = 0 - self.models = self.get_models(BASE_TYPES[self.current_base][0]) - self.model_names = [x[1] for x in self.models] - max_width = max([len(x) for x in self.model_names]) - max_width += 6 - horizontal_layout = max_width * 3 < window_width - - self.add_widget_intelligent( - npyscreen.FixedText, - color="CONTROL", - value="Select two models to merge and optionally a third.", - editable=False, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - color="CONTROL", - value="Use up and down arrows to move, to select an item, and to move from one field to the next.", - editable=False, - ) - self.nextrely += 1 - self.base_select = self.add_widget_intelligent( - SingleSelectColumns, - values=[x[1] for x in BASE_TYPES], - value=[self.current_base], - columns=4, - max_height=2, - relx=8, - scroll_exit=True, - ) - self.base_select.on_changed = self._populate_models - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 1", - color="GOOD", - editable=False, - rely=6 if horizontal_layout else None, - ) - self.model1 = self.add_widget_intelligent( - npyscreen.SelectOne, - values=self.model_names, - value=0, - max_height=len(self.model_names), - max_width=max_width, - scroll_exit=True, - rely=7, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 2", - color="GOOD", - editable=False, - relx=max_width + 3 if horizontal_layout else None, - rely=6 if horizontal_layout else None, - ) - self.model2 = self.add_widget_intelligent( - npyscreen.SelectOne, - name="(2)", - values=self.model_names, - value=1, - max_height=len(self.model_names), - max_width=max_width, - relx=max_width + 3 if horizontal_layout else None, - rely=7 if horizontal_layout else None, - scroll_exit=True, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 3", - color="GOOD", - editable=False, - relx=max_width * 2 + 3 if horizontal_layout else None, - rely=6 if horizontal_layout else None, - ) - models_plus_none = self.model_names.copy() - models_plus_none.insert(0, "None") - self.model3 = self.add_widget_intelligent( - npyscreen.SelectOne, - name="(3)", - values=models_plus_none, - value=0, - max_height=len(self.model_names) + 1, - max_width=max_width, - scroll_exit=True, - relx=max_width * 2 + 3 if horizontal_layout else None, - rely=7 if horizontal_layout else None, - ) - for m in [self.model1, self.model2, self.model3]: - m.when_value_edited = self.models_changed - self.merged_model_name = self.add_widget_intelligent( - TextBox, - name="Name for merged model:", - labelColor="CONTROL", - max_height=3, - value="", - scroll_exit=True, - ) - self.force = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Force merge of models created by different diffusers library versions", - labelColor="CONTROL", - value=True, - scroll_exit=True, - ) - self.nextrely += 1 - self.merge_method = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Merge Method:", - values=self.interpolations, - value=0, - labelColor="CONTROL", - max_height=len(self.interpolations) + 1, - scroll_exit=True, - ) - self.alpha = self.add_widget_intelligent( - FloatTitleSlider, - name="Weight (alpha) to assign to second and third models:", - out_of=1.0, - step=0.01, - lowest=0, - value=0.5, - labelColor="CONTROL", - scroll_exit=True, - ) - self.model1.editing = True - - def models_changed(self) -> None: - models = self.model1.values - selected_model1 = self.model1.value[0] - selected_model2 = self.model2.value[0] - selected_model3 = self.model3.value[0] - merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}" - self.merged_model_name.value = merged_model_name - - if selected_model3 > 0: - self.merge_method.values = ["add_difference ( A+(B-C) )"] - self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one. - else: - self.merge_method.values = self.interpolations - self.merge_method.value = 0 - - def on_ok(self) -> None: - if self.validate_field_values() and self.check_for_overwrite(): - self.parentApp.setNextForm(None) - self.editing = False - self.parentApp.merge_arguments = self.marshall_arguments() - npyscreen.notify("Starting the merge...") - else: - self.editing = True - - def on_cancel(self) -> None: - sys.exit(0) - - def marshall_arguments(self) -> dict: - model_keys = [x[0] for x in self.models] - models = [ - model_keys[self.model1.value[0]], - model_keys[self.model2.value[0]], - ] - if self.model3.value[0] > 0: - models.append(model_keys[self.model3.value[0] - 1]) - interp = "add_difference" - else: - interp = self.interpolations[self.merge_method.value[0]] - - args = { - "model_keys": models, - "alpha": self.alpha.value, - "interp": interp, - "force": self.force.value, - "merged_model_name": self.merged_model_name.value, - } - return args - - def check_for_overwrite(self) -> bool: - model_out = self.merged_model_name.value - if model_out not in self.model_names: - return True - else: - result: bool = npyscreen.notify_yes_no( - f"The chosen merged model destination, {model_out}, is already in use. Overwrite?" - ) - return result - - def validate_field_values(self) -> bool: - bad_fields = [] - model_names = self.model_names - selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]} - if self.model3.value[0] > 0: - selected_models.add(model_names[self.model3.value[0] - 1]) - if len(selected_models) < 2: - bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}") - if len(bad_fields) > 0: - message = "The following problems were detected and must be corrected:" - for problem in bad_fields: - message += f"\n* {problem}" - npyscreen.notify_confirm(message) - return False - else: - return True - - def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name - models = [ - (x.key, x.name) - for x in self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) - if x.format == ModelFormat("diffusers") - and hasattr(x, "variant") - and x.variant == ModelVariantType("normal") - ] - return sorted(models, key=lambda x: x[1]) - - def _populate_models(self, value: List[int]) -> None: - base_model = BASE_TYPES[value[0]][0] - self.models = self.get_models(base_model) - self.model_names = [x[1] for x in self.models] - - models_plus_none = self.model_names.copy() - models_plus_none.insert(0, "None") - self.model1.values = self.model_names - self.model2.values = self.model_names - self.model3.values = models_plus_none - - self.display() - - -# npyscreen is untyped and causes mypy to get naggy -class Mergeapp(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, installer: ModelInstallServiceBase): - """Initialize the npyscreen application.""" - super().__init__() - self.installer = installer - - def onStart(self) -> None: - npyscreen.setTheme(npyscreen.Themes.ElegantTheme) - self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") - - -def run_gui(args: Namespace) -> None: - installer = initialize_installer(config) - mergeapp = Mergeapp(installer) - mergeapp.run() - merge_args = mergeapp.merge_arguments - merger = ModelMerger(installer) - merger.merge_diffusion_models_and_save(**merge_args) - logger.info(f'Models merged into new model: "{merge_args.merged_model_name}".') - - -def run_cli(args: Namespace) -> None: - assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" - assert ( - args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3 - ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." - - if not args.merged_model_name: - args.merged_model_name = "+".join(args.model_names) - logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - - installer = initialize_installer(config) - store = installer.record_store - assert ( - len(store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber - ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - - merger = ModelMerger(installer) - model_keys = [] - for name in args.model_names: - if len(name) == 32 and re.match(r"^[0-9a-f]$", name): - model_keys.append(name) - else: - models = store.search_by_attr( - model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) - ) - assert len(models) > 0, f"{name}: Unknown model" - assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." - model_keys.append(models[0].key) - - merger.merge_diffusion_models_and_save( - alpha=args.alpha, - model_keys=model_keys, - merged_model_name=args.merged_model_name, - interp=args.interp, - force=args.force, - ) - logger.info(f'Models merged into new model: "{args.merged_model_name}".') - - -def main() -> None: - args = _parse_args() - if args.root_dir: - config.parse_args(["--root", str(args.root_dir)]) - else: - config.parse_args([]) - - try: - if args.front_end: - run_gui(args) - else: - run_cli(args) - except widget.NotEnoughSpaceForWidget as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge") - else: - logger.error("Not enough room for the user interface. Try making this window larger.") - sys.exit(-1) - except Exception as e: - logger.error(str(e)) - sys.exit(-1) - except KeyboardInterrupt: - sys.exit(-1) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 86a692f9843..243b0b1f21e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,11 +135,9 @@ dependencies = [ # full commands "invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure" -"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers" -"invokeai-merge2" = "invokeai.frontend.merge.merge_diffusers2:main" +"invokeai-merge" = "invokeai.frontend.merge.merge_diffusers:main" "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion" "invokeai-model-install" = "invokeai.frontend.install.model_install:main" -"invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install "invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main" "invokeai-update" = "invokeai.frontend.install.invokeai_update:main" "invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main" @@ -244,7 +242,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker", diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598f..80308b57aff 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -64,9 +64,7 @@ def mock_services() -> InvocationServices: latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a0..7f9ca8152c1 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -1,4 +1,5 @@ import logging +from unittest.mock import Mock import pytest @@ -65,10 +66,8 @@ def mock_services() -> InvocationServices: invocation_cache=MemoryInvocationCache(max_cache_size=0), latents=None, # type: ignore logger=logging, # type: ignore - model_manager=None, # type: ignore - model_records=None, # type: ignore + model_manager=Mock(), # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 5694432ebdc..55f7e865410 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -20,7 +20,7 @@ ) from invokeai.app.services.model_records import UnknownModelException from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 OS = platform.uname().system diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 46afe0105b5..57515ac81b1 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -8,6 +8,7 @@ import pytest from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ( DuplicateModelException, ModelRecordOrderBy, @@ -25,7 +26,7 @@ ) from invokeai.backend.model_manager.metadata import BaseMetadata from invokeai.backend.util.logging import InvokeAILogger -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.fixtures.sqlite_database import create_mock_sqlite_database @@ -36,7 +37,7 @@ def store( config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) def example_config() -> TextualInversionConfig: diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 6a3ec510a2c..9ed3c9bc507 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -2,7 +2,7 @@ import torch from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType from invokeai.backend.util.test_utils import install_and_load_model diff --git a/tests/backend/model_manager_2/data/invokeai_root/README b/tests/backend/model_manager/data/invokeai_root/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/README rename to tests/backend/model_manager/data/invokeai_root/README diff --git a/tests/backend/model_manager_2/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml b/tests/backend/model_manager/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml rename to tests/backend/model_manager/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml diff --git a/tests/backend/model_manager_2/data/invokeai_root/databases/README b/tests/backend/model_manager/data/invokeai_root/databases/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/databases/README rename to tests/backend/model_manager/data/invokeai_root/databases/README diff --git a/tests/backend/model_manager_2/data/invokeai_root/models/README b/tests/backend/model_manager/data/invokeai_root/models/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/models/README rename to tests/backend/model_manager/data/invokeai_root/models/README diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/model_index.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/model_index.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/model_index.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/model_index.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/scheduler/scheduler_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/scheduler/scheduler_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/scheduler/scheduler_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/scheduler/scheduler_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/merges.txt b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/merges.txt similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/merges.txt rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/merges.txt diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/vocab.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/vocab.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/vocab.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/vocab.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/merges.txt b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/merges.txt similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/merges.txt rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/merges.txt diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/vocab.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/vocab.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/vocab.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/vocab.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test_embedding.safetensors b/tests/backend/model_manager/data/test_files/test_embedding.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test_embedding.safetensors rename to tests/backend/model_manager/data/test_files/test_embedding.safetensors diff --git a/tests/backend/model_manager/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py new file mode 100644 index 00000000000..c1fde504eae --- /dev/null +++ b/tests/backend/model_manager/model_loading/test_model_load.py @@ -0,0 +1,30 @@ +""" +Test model loading +""" + +from pathlib import Path + +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 + + +def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path): + store = mm2_model_manager.store + matches = store.search_by_attr(model_name="test_embedding") + assert len(matches) == 0 + key = mm2_model_manager.install.register_path(embedding_file) + loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key)) + assert loaded_model is not None + assert loaded_model.config.key == key + with loaded_model as model: + assert isinstance(model, TextualInversionModelRaw) + loaded_model_2 = mm2_model_manager.load_model_by_key(key) + assert loaded_model.config.key == loaded_model_2.config.key + + loaded_model_3 = mm2_model_manager.load_model_by_attr( + model_name=loaded_model.config.name, + model_type=loaded_model.config.type, + base_model=loaded_model.config.base, + ) + assert loaded_model.config.key == loaded_model_3.config.key diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py similarity index 77% rename from tests/backend/model_manager_2/model_manager_2_fixtures.py rename to tests/backend/model_manager/model_manager_fixtures.py index d6d091befea..df54e2f9267 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -7,22 +7,26 @@ import pytest from pydantic import BaseModel +from pytest import FixtureRequest from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase -from invokeai.app.services.model_records import ModelRecordServiceSQL +from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase +from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase +from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.util.logging import InvokeAILogger -from tests.backend.model_manager_2.model_metadata.metadata_examples import ( +from tests.backend.model_manager.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, RepoCivitaiVersionMetadata1, RepoHFMetadata1, @@ -85,15 +89,76 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: app_config = InvokeAIAppConfig( root=mm2_root_dir, models_dir=mm2_root_dir / "models", + log_level="info", ) return app_config @pytest.fixture -def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: +def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: + download_queue = DownloadQueueService(requests_session=mm2_session) + download_queue.start() + + def stop_queue() -> None: + download_queue.stop() + + request.addfinalizer(stop_queue) + return download_queue + + +@pytest.fixture +def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: + return mm2_record_store.metadata_store + + +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: + ram_cache = ModelCache( + logger=InvokeAILogger.get_logger(), + max_cache_size=mm2_app_config.ram_cache_size, + max_vram_cache_size=mm2_app_config.vram_cache_size, + ) + convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) + return ModelLoadService( + app_config=mm2_app_config, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + + +@pytest.fixture +def mm2_installer( + mm2_app_config: InvokeAIAppConfig, + mm2_download_queue: DownloadQueueServiceBase, + mm2_session: Session, + request: FixtureRequest, +) -> ModelInstallServiceBase: + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(mm2_app_config, logger) + events = DummyEventService() + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + installer = ModelInstallService( + app_config=mm2_app_config, + record_store=store, + download_queue=mm2_download_queue, + event_bus=events, + session=mm2_session, + ) + installer.start() + + def stop_installer() -> None: + installer.stop() + + request.addfinalizer(stop_installer) + return installer + + +@pytest.fixture +def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) # add five simple config records to the database raw1 = { "path": "/tmp/foo1", @@ -152,15 +217,16 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL @pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore: - db = mm2_record_store._db # to ensure we are sharing the same database - return ModelMetadataStore(db) +def mm2_model_manager( + mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase +) -> ModelManagerServiceBase: + return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader) @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: """This fixtures defines a series of mock URLs for testing download and installation.""" - sess = TestSession() + sess: Session = TestSession() sess.mount( "https://test.com/missing_model.safetensors", TestAdapter( @@ -240,26 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: ), ) return sess - - -@pytest.fixture -def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase: - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(mm2_app_config, logger) - events = DummyEventService() - store = ModelRecordServiceSQL(db) - metadata_store = ModelMetadataStore(db) - - download_queue = DownloadQueueService(requests_session=mm2_session) - download_queue.start() - - installer = ModelInstallService( - app_config=mm2_app_config, - record_store=store, - download_queue=download_queue, - metadata_store=metadata_store, - event_bus=events, - session=mm2_session, - ) - installer.start() - return installer diff --git a/tests/backend/model_manager_2/model_metadata/metadata_examples.py b/tests/backend/model_manager/model_metadata/metadata_examples.py similarity index 100% rename from tests/backend/model_manager_2/model_metadata/metadata_examples.py rename to tests/backend/model_manager/model_metadata/metadata_examples.py diff --git a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py b/tests/backend/model_manager/model_metadata/test_model_metadata.py similarity index 96% rename from tests/backend/model_manager_2/model_metadata/test_model_metadata.py rename to tests/backend/model_manager/model_metadata/test_model_metadata.py index 5a2ec937673..09b18916d38 100644 --- a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py +++ b/tests/backend/model_manager/model_metadata/test_model_metadata.py @@ -8,6 +8,7 @@ from pydantic.networks import HttpUrl from requests.sessions import Session +from invokeai.app.services.model_metadata import ModelMetadataStoreBase from invokeai.backend.model_manager.config import ModelRepoVariant from invokeai.backend.model_manager.metadata import ( CivitaiMetadata, @@ -15,14 +16,13 @@ CommercialUsage, HuggingFaceMetadata, HuggingFaceMetadataFetch, - ModelMetadataStore, UnknownMetadataException, ) from invokeai.backend.model_manager.util import select_hf_files -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 -def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None: +def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None: tags = {"text-to-image", "diffusers"} input_metadata = HuggingFaceMetadata( name="sdxl-vae", @@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None: assert mm2_metadata_store.list_tags() == tags -def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None: +def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None: input_metadata = HuggingFaceMetadata( name="sdxl-vae", author="stabilityai", @@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None: assert input_metadata == output_metadata -def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None: +def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None: metadata1 = HuggingFaceMetadata( name="sdxl-vae", author="stabilityai", diff --git a/tests/backend/model_management/test_libc_util.py b/tests/backend/model_manager/test_libc_util.py similarity index 88% rename from tests/backend/model_management/test_libc_util.py rename to tests/backend/model_manager/test_libc_util.py index e13a2fd3a2f..4309dc7c34c 100644 --- a/tests/backend/model_management/test_libc_util.py +++ b/tests/backend/model_manager/test_libc_util.py @@ -1,6 +1,6 @@ import pytest -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2 def test_libc_util_mallinfo2(): diff --git a/tests/backend/model_management/test_lora.py b/tests/backend/model_manager/test_lora.py similarity index 96% rename from tests/backend/model_management/test_lora.py rename to tests/backend/model_manager/test_lora.py index 14bcc87c892..114a4cfdcff 100644 --- a/tests/backend/model_management/test_lora.py +++ b/tests/backend/model_manager/test_lora.py @@ -5,8 +5,8 @@ import pytest import torch -from invokeai.backend.model_management.lora import ModelPatcher -from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw +from invokeai.backend.lora import LoRALayer, LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher @pytest.mark.parametrize( diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/backend/model_manager/test_memory_snapshot.py similarity index 87% rename from tests/backend/model_management/test_memory_snapshot.py rename to tests/backend/model_manager/test_memory_snapshot.py index 216cd62171d..d31ae79b668 100644 --- a/tests/backend/model_management/test_memory_snapshot.py +++ b/tests/backend/model_manager/test_memory_snapshot.py @@ -1,7 +1,7 @@ import pytest -from invokeai.backend.model_management.libc_util import Struct_mallinfo2 -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 def test_memory_snapshot_capture(): @@ -26,6 +26,7 @@ def test_memory_snapshot_capture(): def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2): """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields.""" msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2) + print(msg) expected_lines = 0 if snapshot_1 is not None and snapshot_2 is not None: diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_manager/test_model_load_optimization.py similarity index 96% rename from tests/backend/model_management/test_model_load_optimization.py rename to tests/backend/model_manager/test_model_load_optimization.py index a4fe1dd5974..f627f3a2982 100644 --- a/tests/backend/model_management/test_model_load_optimization.py +++ b/tests/backend/model_manager/test_model_load_optimization.py @@ -1,7 +1,7 @@ import pytest import torch -from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init +from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init @pytest.mark.parametrize( diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py similarity index 99% rename from tests/backend/model_manager_2/util/test_hf_model_select.py rename to tests/backend/model_manager/util/test_hf_model_select.py index f14d9a6823a..5bef9cb2e19 100644 --- a/tests/backend/model_manager_2/util/test_hf_model_select.py +++ b/tests/backend/model_manager/util/test_hf_model_select.py @@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]: "text_encoder/model.onnx", "text_encoder_2/config.json", "text_encoder_2/model.onnx", + "text_encoder_2/model.onnx_data", "tokenizer/merges.txt", "tokenizer/special_tokens_map.json", "tokenizer/tokenizer_config.json", @@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]: "tokenizer_2/vocab.json", "unet/config.json", "unet/model.onnx", + "unet/model.onnx_data", "vae_decoder/config.json", "vae_decoder/model.onnx", "vae_encoder/config.json", diff --git a/tests/conftest.py b/tests/conftest.py index 6e7d559be44..1c816002296 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,2 @@ # conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory # without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html) - - -# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not -# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. -from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401 diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py deleted file mode 100644 index 3e48c7ed6fc..00000000000 --- a/tests/test_model_manager.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path - -import pytest - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType - -BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main) - - -@pytest.fixture -def model_manager(datadir) -> ModelManager: - InvokeAIAppConfig.get_config(root=datadir) - return ModelManager(datadir / "configs" / "relative_sub.models.yaml") - - -def test_get_model_names(model_manager: ModelManager): - names = model_manager.model_names() - assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] - - -def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) - top_model_path, is_override = model_manager._get_model_path(model_config) - expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" - assert top_model_path == expected_model_path - assert not is_override - - -def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" - assert vae_model_path == expected_vae_path - assert is_override - - -def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - assert not is_override diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 248b7d602fd..be823e2be9f 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -2,8 +2,8 @@ import pytest -from invokeai.backend import BaseModelType -from invokeai.backend.model_management.model_probe import VaeFolderProbe +from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant +from invokeai.backend.model_manager.probe import VaeFolderProbe @pytest.mark.parametrize( @@ -20,3 +20,11 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat probe = VaeFolderProbe(sd1_vae_path) base_type = probe.get_base_type() assert base_type == expected_type + repo_variant = probe.get_repo_variant() + assert repo_variant == ModelRepoVariant.DEFAULT + + +def test_repo_variant(datadir: Path): + probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") + repo_variant = probe.get_repo_variant() + assert repo_variant == ModelRepoVariant.FP16 diff --git a/tests/test_model_probe/vae/taesdxl-fp16/config.json b/tests/test_model_probe/vae/taesdxl-fp16/config.json new file mode 100644 index 00000000000..62f01c3eb44 --- /dev/null +++ b/tests/test_model_probe/vae/taesdxl-fp16/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +} diff --git a/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors b/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors new file mode 100644 index 00000000000..e69de29bb2d