diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md new file mode 100644 index 00000000000..5529d1feeb7 --- /dev/null +++ b/docs/contributing/MODEL_MANAGER.md @@ -0,0 +1,1214 @@ +# Introduction to the Model Manager V2 + +The Model Manager is responsible for organizing the various machine +learning models used by InvokeAI. It consists of a series of +interdependent services that together handle the full lifecycle of a +model. These are the: + +* _ModelRecordServiceBase_ Responsible for managing model metadata and + configuration information. Among other things, the record service + tracks the type of the model, its provenance, and where it can be + found on disk. + +* _ModelLoadServiceBase_ Responsible for loading a model from disk + into RAM and VRAM and getting it ready for inference/training. + +* _DownloadQueueServiceBase_ A multithreaded downloader responsible + for downloading models from a remote source to disk. The download + queue has special methods for downloading repo_id folders from + Hugging Face, as well as discriminating among model versions in + Civitai, but can be used for arbitrary content. + +* _ModelInstallServiceBase_ A service for installing models to + disk. It uses `DownloadQueueServiceBase` to download models and + their metadata, and `ModelRecordServiceBase` to store that + information. It is also responsible for managing the InvokeAI + `models` directory and its contents. + +## Location of the Code + +All four of these services can be found in +`invokeai/app/services` in the following files: + +* `invokeai/app/services/model_record_service.py` +* `invokeai/app/services/download_manager.py` (needs a name change) +* `invokeai/app/services/model_loader_service.py` +* `invokeai/app/services/model_install_service.py` + +With the exception of the install service, each of these is a thin +shell around a corresponding implementation located in +`invokeai/backend/model_manager`. The main difference between the +modules found in app services and those in the backend folder is that +the former add support for event reporting and are more tied to the +needs of the InvokeAI API. + +Code related to the FastAPI web API can be found in +`invokeai/app/api/routers/models.py`. + +*** + +## What's in a Model? The ModelRecordService + +The `ModelRecordService` manages the model's metadata. It supports a +hierarchy of pydantic metadata "config" objects, which become +increasingly specialized to support particular model types. + +### ModelConfigBase + +All model metadata classes inherit from this pydantic class. it +provides the following fields: + +| **Field Name** | **Type** | **Description** | +|----------------|-----------------|------------------| +| `key` | str | Unique identifier for the model | +| `name` | str | Name of the model (not unique) | +| `model_type` | ModelType | The type of the model | +| `model_format` | ModelFormat | The format of the model (e.g. 
"diffusers"); also used as a Union discriminator | +| `base_model` | BaseModelType | The base model that the model is compatible with | +| `path` | str | Location of model on disk | +| `hash` | str | Most recent hash of the model's contents | +| `description` | str | Human-readable description of the model (optional) | +| `author` | str | Name of the model's author (optional) | +| `license` | str | Model's licensing model, as reported by the download source (optional) | +| `source` | str | Model's source URL or repo id (optional) | +| `thumbnail_url` | str | A thumbnail preview of model output, as reported by its source (optional) | +| `tags` | List[str] | A list of tags associated with the model, as reported by its source (optional) | + +The `key` is a unique 32-character hash which is originally obtained +by sampling several parts of the model's files using the `imohash` +library. If the model is altered within InvokeAI (typically by +converting a checkpoint to a diffusers model) the key will remain the +same. The `hash` field holds the current hash of the model. It starts +out being the same as `key`, but may diverge. + +`ModelType`, `ModelFormat` and `BaseModelType` are string enums that +are defined in `invokeai.backend.model_manager.config`. They are also +imported by, and can be reexported from, +`invokeai.app.services.model_record_service`: + +``` +from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType +``` + +The `path` field can be absolute or relative. If relative, it is taken +to be relative to the `models_dir` setting in the user's +`invokeai.yaml` file. + + +### CheckpointConfig + +This adds support for checkpoint configurations, and adds the +following field: + +| **Field Name** | **Type** | **Description** | +|----------------|-----------------|------------------| +| `config` | str | Path to the checkpoint's config file | + +`config` is the path to the checkpoint's config file. If relative, it +is taken to be relative to the InvokeAI root directory +(e.g. `configs/stable-diffusion/v1-inference.yaml`) + +### MainConfig + +This adds support for "main" Stable Diffusion models, and adds these +fields: + +| **Field Name** | **Type** | **Description** | +|----------------|-----------------|------------------| +| `vae` | str | Path to a VAE to use instead of the burnt-in one | +| `variant` | ModelVariantType| Model variant type, such as "inpainting" | + +`vae` can be an absolute or relative path. If relative, its base is +taken to be the `models_dir` directory. + +`variant` is an enumerated string class with values `normal`, +`inpaint` and `depth`. If needed, it can be imported if needed from +either `invokeai.app.services.model_record_service` or +`invokeai.backend.model_manager.config`. + +### ONNXSD2Config + +| **Field Name** | **Type** | **Description** | +|----------------|-----------------|------------------| +| `prediction_type` | SchedulerPredictionType | Scheduler prediction type to use, e.g. "epsilon" | +| `upcast_attention` | bool | Model requires its attention module to be upcast | + +The `SchedulerPredictionType` enum can be imported from either +`invokeai.app.services.model_record_service` or +`invokeai.backend.model_manager.config`. + +### Other config classes + +There are a series of such classes each discriminated by their +`ModelFormat`, including `LoRAConfig`, `IPAdapterConfig`, and so +forth. 
These are rarely needed outside the model manager's internal +code, but available in `invokeai.backend.model_manager.config` if +needed. There is also a Union of all ModelConfig classes, called +`AnyModelConfig` that can be imported from the same file. + +### Limitations of the Data Model + +The config hierarchy has a major limitation in its handling of the +base model type. Each model can only be compatible with one base +model, which breaks down in the event of models that are compatible +with two or more base models. For example, SD-1 VAEs also work with +SD-2 models. A partial workaround is to use `BaseModelType.Any`, which +indicates that the model is compatible with any of the base +models. This works OK for some models, such as the IP Adapter image +encoders, but is an all-or-nothing proposition. + +Another issue is that the config class hierarchy is paralleled to some +extent by a `ModelBase` class hierarchy defined in +`invokeai.backend.model_manager.models.base` and its subclasses. These +are classes representing the models after they are loaded into RAM and +include runtime information such as load status and bytes used. Some +of the fields, including `name`, `model_type` and `base_model`, are +shared between `ModelConfigBase` and `ModelBase`, and this is a +potential source of confusion. + +** TO DO: ** The `ModelBase` code needs to be revised to reduce the +duplication of similar classes and to support using the `key` as the +primary model identifier. + +## Reading and Writing Model Configuration Records + +The `ModelRecordService` provides the ability to retrieve model +configuration records from SQL or YAML databases, update them, and +write them back. + +A application-wide `ModelRecordService` is created during API +initialization and can be retrieved within an invocation from the +`InvocationContext` object: + +``` +store = context.services.model_record_store +``` + +or from elsewhere in the code by accessing +`ApiDependencies.invoker.services.model_record_store`. + +### Creating a `ModelRecordService` + +To create a new `ModelRecordService` database or open an existing one, +you can directly create either a `ModelRecordServiceSQL` or a +`ModelRecordServiceFile` object: + +``` +from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile + +store = ModelRecordServiceSQL.from_connection(connection, lock) +store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db') +store = ModelRecordServiceFile.from_db_file('/path/to/database.yaml') +``` + +The `from_connection()` form is only available from the +`ModelRecordServiceSQL` class, and is used to manage records in a +previously-opened SQLITE3 database using a `sqlite3.connection` object +and a `threading.lock` object. It is intended for the specific use +case of storing the record information in the main InvokeAI database, +usually `databases/invokeai.db`. + +The `from_db_file()` methods can be used to open new connections to +the named database files. If the file doesn't exist, it will be +created and initialized. + +As a convenience, `ModelRecordServiceBase` offers two methods, +`from_db_file` and `open`, which will return either a SQL or File +implementation depending on the context. The former looks at the file +extension to determine whether to open the file as a SQL database +(".db") or as a file database (".yaml"). 
If the file exists, but is +either the wrong type or does not contain the expected schema +metainformation, then an appropriate `AssertionError` will be raised: + +``` +store = ModelRecordServiceBase.from_db_file('/path/to/a/file.{yaml,db}') +``` + +The `ModelRecordServiceBase.open()` method is specifically designed for use in the InvokeAI +web server and to maintain compatibility with earlier iterations of +the model manager. Its signature is: + +``` +def open( + cls, + config: InvokeAIAppConfig, + conn: Optional[sqlite3.Connection] = None, + lock: Optional[threading.Lock] = None + ) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]: +``` + +The way it works is as follows: + +1. Retrieve the value of the `model_config_db` option from the user's + `invokeai.yaml` config file. +2. If `model_config_db` is `auto` (the default), then: + - Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object + opened on the passed connection and lock. + - Open up a new connection to `databases/invokeai.db` if `conn` + and/or `lock` are missing (see note below). +3. If `model_config_db` is a Path, then use `from_db_file` + to return the appropriate type of ModelRecordService. +4. If `model_config_db` is None, then retrieve the legacy + `conf_path` option from `invokeai.yaml` and use the Path + indicated there. This will default to `configs/models.yaml`. + +So a typical startup pattern would be: + +``` +import sqlite3 +from invokeai.app.services.thread import lock +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.config import InvokeAIAppConfig + +config = InvokeAIAppConfig.get_config() +db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False) +store = ModelRecordServiceBase.open(config, db_conn, lock) +``` + +_A note on simultaneous access to `invokeai.db`_: The current InvokeAI +service architecture for the image and graph databases is careful to +use a shared sqlite3 connection and a thread lock to ensure that two +threads don't attempt to access the database simultaneously. However, +the default `sqlite3` library used by Python reports using +**Serialized** mode, which allows multiple threads to access the +database simultaneously using multiple database connections (see +https://www.sqlite.org/threadsafe.html and +https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore +it should be safe to allow the record service to open its own SQLite +database connection. Opening a model record service should then be as +simple as `ModelRecordServiceBase.open(config)`. + +### Fetching a Model's Configuration from `ModelRecordServiceBase` + +Configurations can be retrieved in several ways. + +#### get_model(key) -> AnyModelConfig: + +The basic functionality is to call the record store object's +`get_model()` method with the desired model's unique key. It returns +the appropriate subclass of ModelConfigBase: + +``` +model_conf = store.get_model('f13dd932c0c35c22dcb8d6cda4203764') +print(model_conf.path) + +>> '/tmp/models/ckpts/v1-5-pruned-emaonly.safetensors' + +``` + +If the key is unrecognized, this call raises an +`UnknownModelException`. + +#### exists(key) -> AnyModelConfig: + +Returns True if a model with the given key exists in the databsae. + +#### search_by_path(path) -> AnyModelConfig: + +Returns the configuration of the model whose path is `path`. The path +is matched using a simple string comparison and won't correctly match +models referred to by different paths (e.g. 
using symbolic links).
+
+#### search_by_name(name, base, type) -> List[AnyModelConfig]:
+
+This method searches for models that match some combination of `name`,
+`BaseModelType` and `ModelType`. Calling it without any arguments will
+return all the models in the database.
+
+#### all_models() -> List[AnyModelConfig]:
+
+Return all the model configs in the database. Exactly equivalent to
+calling `search_by_name()` with no arguments.
+
+#### search_by_tag(tags) -> List[AnyModelConfig]:
+
+`tags` is a list of strings. This method returns a list of model
+configs that contain all of the given tags. Example:
+
+```
+# find all models that are marked as both SFW and as generating
+# background scenery
+configs = store.search_by_tag(['sfw', 'scenery'])
+```
+
+Note that only tags are searchable in this way. Other fields can be
+searched by filtering the results of `all_models()`:
+
+```
+commercializable_models = [
+    x for x in store.all_models()
+    if x.license and 'allowCommercialUse=Sell' in x.license
+]
+```
+
+#### version() -> str:
+
+Returns the version of the database, currently `3.2`.
+
+#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
+
+This method exists to ease the transition from the previous version of
+the model manager, in which `get_model()` took the three arguments
+shown above. It looks for a unique model identified by name, base
+model and model type, and returns it.
+
+The method will raise a `DuplicateModelException` if there is more
+than one model sharing the same type, base and name. While unlikely,
+it is certainly possible for a user to have added two models with the
+same name, base and type, one located at path `/foo/my_model` and the
+other at `/bar/my_model`. It is strongly recommended to search for
+models using `search_by_name()`, which can return multiple results,
+and then to select the desired model and pass its key to
+`get_model()`.
+
+### Writing model configs to the database
+
+Several methods allow you to create and update stored model config
+records.
+
+#### add_model(key, config) -> ModelConfigBase:
+
+Given a key and a configuration, this will add the model's
+configuration record to the database. `config` can either be a
+subclass of `ModelConfigBase` (i.e. any class listed in
+`AnyModelConfig`), or a `dict` of key/value pairs. In the latter case,
+the correct configuration class will be picked by Pydantic's
+discriminated union mechanism.
+
+If successful, the method will return the appropriate subclass of
+`ModelConfigBase`. It will raise a `DuplicateModelException` if a
+model with the same key is already in the database, or an
+`InvalidModelConfigException` if a dict was passed and Pydantic
+experienced a parse or validation error.
+
+#### update_model(key, config) -> AnyModelConfig:
+
+Given a key and a configuration, this will update the model
+configuration record in the database. `config` can be either an
+instance of `ModelConfigBase` or a sparse `dict` containing the fields
+to be updated. This will return an `AnyModelConfig` on success, or
+raise `InvalidModelConfigException` or `UnknownModelException` on
+failure.
+
+***TO DO:*** Investigate why `update_model()` returns an
+`AnyModelConfig` while `add_model()` returns a `ModelConfigBase`.
+
+#### rename_model(key, new_name) -> ModelConfigBase:
+
+This is a special case of `update_model()` for the use case of
+changing the model's name.
It is broken out because there are cases in +which the InvokeAI application wants to synchronize the model's name +with its path in the `models` directory after changing the name, type +or base. However, when using the ModelRecordService directly, the call +is equivalent to: + +``` +store.rename_model(key, {'name': 'new_name'}) +``` + +***TO DO:*** Investigate why `rename_model()` is returning a +`ModelConfigBase` while `update_model()` returns a `AnyModelConfig`. + +*** + +## Let's get loaded, the lowdown on ModelLoadService + +The `ModelLoadService` is responsible for loading a named model into +memory so that it can be used for inference. Despite the fact that it +does a lot under the covers, it is very straightforward to use. + +An application-wide model loader is created at API initialization time +and stored in +`ApiDependencies.invoker.services.model_loader`. However, you can +create alternative instances if you wish. + +### Creating a ModelLoadService object + +The class is defined in +`invokeai.app.services.model_loader_service`. It is initialized with +an InvokeAIAppConfig object, from which it gets configuration +information such as the user's desired GPU and precision, and with a +previously-created `ModelRecordServiceBase` object, from which it +loads the requested model's configuration information. + +Here is a typical initialization pattern: + +``` +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.model_loader_service import ModelLoadService + +config = InvokeAIAppConfig.get_config() +store = ModelRecordServiceBase.open(config) +loader = ModelLoadService(config, store) +``` + +Note that we are relying on the contents of the application +configuration to choose the implementation of +`ModelRecordServiceBase`. + +### get_model(key, [submodel_type], [context]) -> ModelInfo: + +The `get_model()` method, like its similarly-named cousin in +`ModelRecordService`, receives the unique key that identifies the +model. It loads the model into memory, gets the model ready for use, +and returns a `ModelInfo` object. + +The optional second argument, `subtype` is a `SubModelType` string +enum, such as "vae". It is mandatory when used with a main model, and +is used to select which part of the main model to load. + +The optional third argument, `invocation_context` can be provided by +an invocation to trigger model load event reporting. See below for +details. + +The returned `ModelInfo` object shares some fields in common with +`ModelConfigBase`, but is otherwise a completely different beast: + +| **Field Name** | **Type** | **Description** | +|----------------|-----------------|------------------| +| `key` | str | The model key derived from the ModelRecordService database | +| `name` | str | Name of this model | +| `base_model` | BaseModelType | Base model for this model | +| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)| +| `location` | Path or str | Location of the model on the filesystem | +| `precision` | torch.dtype | The torch.precision to use for inference | +| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use | + +The types for `ModelInfo` and `SubModelType` can be imported from +`invokeai.app.services.model_loader_service`. 
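+
+As a brief sketch (assuming an invocation's `context` object and the
+hypothetical model key used earlier in this document), an invocation
+might fetch a model and request load-event reporting like this:
+
+```
+from invokeai.app.services.model_loader_service import ModelInfo, SubModelType
+
+loader = context.services.model_loader  # application-wide loader
+model_info: ModelInfo = loader.get_model(
+    'f13dd932c0c35c22dcb8d6cda4203764',  # model key from the record store
+    SubModelType('vae'),                 # required when loading a main model
+    context,                             # enables model_load_started/completed events
+)
+```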
+ +To use the model, you use the `ModelInfo` as a context manager using +the following pattern: + +``` +model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) +with model_info as vae: + image = vae.decode(latents)[0] +``` + +The `vae` model will stay locked in the GPU during the period of time +it is in the context manager's scope. + +`get_model()` may raise any of the following exceptions: + +- `UnknownModelException` -- key not in database +- `ModelNotFoundException` -- key in database but model not found at path +- `InvalidModelException` -- the model is guilty of a variety of sins + +** TO DO: ** Resolve discrepancy between ModelInfo.location and +ModelConfig.path. + +### Emitting model loading events + +When the `context` argument is passed to `get_model()`, it will +retrieve the invocation event bus from the passed `InvocationContext` +object to emit events on the invocation bus. The two events are +"model_load_started" and "model_load_completed". Both carry the +following payload: + +``` +payload=dict( + queue_id=queue_id, + queue_item_id=queue_item_id, + queue_batch_id=queue_batch_id, + graph_execution_state_id=graph_execution_state_id, + model_key=model_key, + submodel=submodel, + hash=model_info.hash, + location=str(model_info.location), + precision=str(model_info.precision), +) +``` + +*** + +## Get on line: The Download Queue + +InvokeAI can download arbitrary files using a multithreaded background +download queue. Internally, the download queue is used for installing +models located at remote locations. The queue is implemented by the +`DownloadQueueService` defined in +`invokeai.app.services.download_manager`. However, most of the +implementation is spread out among several files in +`invokeai/backend/model_manager/download/*` + +A default download queue is located in +`ApiDependencies.invoker.services.download_queue`. However, you can +create additional instances if you need to isolate your queue from the +main one. + +### A job for every task + +The queue operates on a series of download job objects. These objects +specify the source and destination of the download, and keep track of +the progress of the download. Jobs come in a variety of shapes and +colors as they are progressively specialized for particular download +task. + +The basic job is the `DownloadJobBase`, a pydantic object with the +following fields: + +| **Field** | **Type** | **Default** | **Description** | +|----------------|-----------------|---------------|-----------------| +| `id` | int | | Job ID, an integer >= 0 | +| `priority` | int | 10 | Job priority. Lower priorities run before higher priorities | +| `source` | Any | | Where to download from - untyped in base class | +| `destination` | Path | | Where to download to | +| `status` | DownloadJobStatus| Idle | Job's status (see below) | +| `event_handlers` | List[DownloadEventHandler]| | Event handlers (see below) | +| `job_started` | float | | Timestamp for when the job started running | +| `job_ended` | float | | Timestamp for when the job completed or errored out | +| `job_sequence` | int | | A counter that is incremented each time a model is dequeued | +| `preserve_partial_downloads`| bool | False | Resume partial downloads when relaunched. | +| `error` | Exception | | A copy of the Exception that caused an error during download | + +When you create a job, you can assign it a `priority`. If multiple +jobs are queued, the job with the lowest priority runs first. (Don't +blame me! 
The Unix developers came up with this convention.) + +Every job has a `source` and a `destination`. `source` is an untyped +object in the base class, but subclassses redefine it more +specifically. + +The `destination` must be the Path to a file or directory on the local +filesystem. If the Path points to a new or existing file, then the +source will be stored under that filename. If the Path ponts to an +existing directory, then the downloaded file will be stored inside the +directory, usually using the name assigned to it at the remote site in +the `content-disposition` http field. + +When the job is submitted, it is assigned a numeric `id`. The id can +then be used to control the job, such as starting, stopping and +cancelling its download. + +The `status` field is updated by the queue to indicate where the job +is in its lifecycle. Values are defined in the string enum +`DownloadJobStatus`, a symbol available from +`invokeai.app.services.download_manager`. Possible values are: + +| **Value** | **String Value** | ** Description ** | +|--------------|---------------------|-------------------| +| `IDLE` | idle | Job created, but not submitted to the queue | +| `ENQUEUED` | enqueued | Job is patiently waiting on the queue | +| `RUNNING` | running | Job is running! | +| `PAUSED` | paused | Job was paused and can be restarted | +| `COMPLETED` | completed | Job has finished its work without an error | +| `ERROR` | error | Job encountered an error and will not run again| +| `CANCELLED` | cancelled | Job was cancelled and will not run (again) | + +`job_started`, `job_ended` and `job_sequence` indicate when the job +was started (using a python timestamp), when it completed, and the +order in which it was taken off the queue. These are mostly used for +debugging and performance testing. + +In case of an error, the Exception that caused the error will be +placed in the `error` field, and the job's status will be set to +`DownloadJobStatus.ERROR`. + +After an error occurs, any partially downloaded files will be deleted +from disk, unless `preserve_partial_downloads` was set to True at job +creation time (or set to True any time before the error +occurred). Note that since most InvokeAI model install operations +involve downloading files to a temporary directory that has a limited +lifetime, this flag is not used by the model installer. + +There are a series of subclasses of `DownloadJobBase` that provide +support for specific types of downloads. These are: + +#### DownloadJobPath + +This subclass redefines `source` to be a filesystem Path. It is used +to move a file or directory from the `source` to the `destination` +paths in the background using a uniform event-based infrastructure. + +#### DownloadJobRemoteSource + +This subclass adds the following fields to the job: + +| **Field** | **Type** | **Default** | **Description** | +|----------------|-----------------|---------------|-----------------| +| `bytes` | int | 0 | bytes downloaded so far | +| `total_bytes` | int | 0 | total size to download | +| `access_token` | Any | None | an authorization token to present to the remote source | + +The job will start out with 0/0 in its bytes/total_bytes fields. Once +it starts running, `total_bytes` will be populated from information +provided in the HTTP download header (if available), and the number of +bytes downloaded so far will be progressively incremented. + +#### DownloadJobURL + +This is a subclass of `DownloadJobBase`. 
It redefines `source` to be a +Pydantic `AnyHttpUrl` object, which enforces URL validation checking +on the field. + +Note that the installer service defines an additional subclass of +`DownloadJobRemoteSource` that accepts HuggingFace repo_ids in +addition to URLs. This is discussed later in this document. + +### Event handlers + +While a job is being downloaded, the queue will emit events at +periodic intervals. A typical series of events during a successful +download session will look like this: + +- enqueued +- running +- running +- running +- completed + +There will be a single enqueued event, followed by one or more running +events, and finally one `completed`, `error` or `cancelled` +events. + +It is possible for a caller to pause download temporarily, in which +case the events may look something like this: + +- enqueued +- running +- running +- paused +- running +- completed + +The download queue logs when downloads start and end (unless `quiet` +is set to True at initialization time) but doesn't log any progress +events. You will probably want to be alerted to events during the +download job and provide more user feedback. In order to intercept and +respond to events you may install a series of one or more event +handlers in the job. Whenever the job's status changes, the chain of +event handlers is traversed and executed in the same thread that the +download job is running in. + +Event handlers have the signature `Callable[["DownloadJobBase"], +None]`, i.e. + +``` +def handler(job: DownloadJobBase): + pass +``` + +A typical handler will examine `job.status` and decide if there's +something to be done. This can include cancelling or erroring the job, +but more typically is used to report on the job status to the user +interface or to perform certain actions on successful completion of +the job. + +Event handlers can be attached to a job at creation time. In addition, +you can create a series of default handlers that are attached to the +queue object itself. These handlers will be executed for each job +after the job's own handlers (if any) have run. + +During a download, running events are issued every time roughly 1% of +the file is transferred. This is to provide just enough granularity to +update a tqdm progress bar smoothly. + +Handlers can be added to a job after the fact using the job's +`add_event_handler` method: + +``` +job.add_event_handler(my_handler) +``` + +All handlers can be cleared using the job's `clear_event_handlers()` +method. Note that it might be a good idea to pause the job before +altering its handlers. 
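+
+For example, a minimal progress-reporting handler might look like the
+sketch below. It assumes that `DownloadJobBase` and
+`DownloadJobRemoteSource` can be imported from
+`invokeai.app.services.download_manager` alongside `DownloadJobStatus`:
+
+```
+from invokeai.app.services.download_manager import (
+    DownloadJobBase,
+    DownloadJobRemoteSource,
+    DownloadJobStatus,
+)
+
+def progress_handler(job: DownloadJobBase):
+    # Only remote-source jobs carry byte counts.
+    if isinstance(job, DownloadJobRemoteSource) and job.status == DownloadJobStatus.RUNNING:
+        if job.total_bytes > 0:
+            print(f"job {job.id}: {100 * job.bytes / job.total_bytes:.0f}% downloaded")
+    elif job.status == DownloadJobStatus.ERROR:
+        print(f"job {job.id} failed: {job.error}")
+
+# Attach with job.add_event_handler(progress_handler), or pass it in the
+# job's event_handlers list at creation time.
+```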
+ +### Creating a download queue object + +The `DownloadQueueService` constructor takes the following arguments: + +| **Argument** | **Type** | **Default** | **Description** | +|----------------|-----------------|---------------|-----------------| +| `event_handlers` | List[DownloadEventHandler] | [] | Event handlers | +| `max_parallel_dl` | int | 5 | Maximum number of simultaneous downloads allowed | +| `requests_session` | requests.sessions.Session | None | An alternative requests Session object to use for the download | +| `quiet` | bool | False| Do work quietly without issuing log messages | + +A typical initialization sequence will look like: + +``` +from invokeai.app.services.download_manager import DownloadQueueService + +def log_download_event(job: DownloadJobBase): + logger.info(f'job={job.id}: status={job.status}') + +queue = DownloadQueueService( + event_handlers=[log_download_event] + ) +``` + +Event handlers can be provided to the queue at initialization time as +shown in the example. These will be automatically appended to the +handler list for any job that is submitted to this queue. + +`max_parallel_dl` sets the number of simultaneous active downloads +that are allowed. The default of five has not been benchmarked in any +way, but seems to give acceptable performance. + +`requests_session` can be used to provide a `requests` module Session +object that will be used to stream remote URLs to disk. This facility +was added for use in the module's unit tests to simulate a remote web +server, but may be useful in other contexts. + +`quiet` will prevent the queue from issuing any log messages at the +INFO or higher levels. + +### Submitting a download job + +You can submit a download job to the queue either by creating the job +manually and passing it to the queue's `submit_download_job()` method, +or using the `create_download_job()` method, which will do the same +thing on your behalf. + +To use the former method, follow this example: + +``` +job = DownloadJobRemoteSource( + source='http://www.civitai.com/models/13456', + destination='/tmp/models/', + event_handlers=[my_handler1, my_handler2], # if desired + ) +queue.submit_download_job(job, start=True) +``` + +`submit_download_job()` takes just two arguments: the job to submit, +and a flag indicating whether to immediately start the job (defaulting +to True). If you choose not to start the job immediately, you can +start it later by calling the queue's `start_job()` or +`start_all_jobs()` methods, which are described later. + +To have the queue create the job for you, follow this example instead: + +``` +job = queue.create_download_job( + source='http://www.civitai.com/models/13456', + destdir='/tmp/models/', + filename='my_model.safetensors', + event_handlers=[my_handler1, my_handler2], # if desired + start=True, + ) + ``` + +The `filename` argument forces the downloader to use the specified +name for the file rather than the name provided by the remote source, +and is equivalent to manually specifying a destination of +`/tmp/models/my_model.safetensors' in the submitted job. 
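+
+If you prefer not to start the download right away, a sketch of the
+deferred-start pattern described above (reusing the `queue` object
+created earlier) might look like this:
+
+```
+job = queue.create_download_job(
+    source='https://civitai.com/api/download/models/128713',
+    destdir='/tmp/models/',
+    start=False,          # create and register the job, but do not start it yet
+)
+# ...some time later...
+queue.start_job(job)      # or queue.start_all_jobs()
+```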
+ +Here is the full list of arguments that can be provided to +`create_download_job()`: + + +| **Argument** | **Type** | **Default** | **Description** | +|------------------|------------------------------|-------------|-------------------------------------------| +| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source | +| `destdir` | Path | | Destination directory for downloaded file | +| `filename` | Path | None | Filename for downloaded file | +| `start` | bool | True | Enqueue the job immediately | +| `priority` | int | 10 | Starting priority for this job | +| `access_token` | str | None | Authorization token for this resource | +| `event_handlers` | List[DownloadEventHandler] | [] | Event handlers for this job | + +Internally, `create_download_job()` has a little bit of internal logic +that looks at the type of the source and selects the right subclass of +`DownloadJobBase` to create and enqueue. + +**TODO**: move this logic into its own method for overriding in +subclasses. + +### Job control + +Prior to completion, jobs can be controlled with a series of queue +method calls. Do not attempt to modify jobs by directly writing to +their fields, as this is likely to lead to unexpected results. + +Any method that accepts a job argument may raise an +`UnknownJobIDException` if the job has not yet been submitted to the +queue or was not created by this queue. + +#### queue.join() + +This method will block until all the active jobs in the queue have +reached a terminal state (completed, errored or cancelled). + +#### jobs = queue.list_jobs() + +This will return a list of all jobs, including ones that have not yet +been enqueued and those that have completed or errored out. + +#### job = queue.id_to_job(int) + +This method allows you to recover a submitted job using its ID. + +#### queue.prune_jobs() + +Remove completed and errored jobs from the job list. + +#### queue.start_job(job) + +If the job was submitted with `start=False`, then it can be started +using this method. + +#### queue.pause_job(job) + +This will temporarily pause the job, if possible. It can later be +restarted and pick up where it left off using `queue.start_job()`. + +#### queue.cancel_job(job) + +This will cancel the job if possible and clean up temporary files and +other resources that it might have been using. + +#### queue.start_all_jobs(), queue.pause_all_jobs(), queue.cancel_all_jobs() + +This will start/pause/cancel all jobs that have been submitted to the +queue and have not yet reached a terminal state. + +## Model installation + +The `ModelInstallService` class implements the +`ModelInstallServiceBase` abstract base class, and provides a one-stop +shop for all your model install needs. It provides the following +functionality: + +- Registering a model config record for a model already located on the + local filesystem, without moving it or changing its path. + +- Installing a model alreadiy located on the local filesystem, by + moving it into the InvokeAI root directory under the + `models` folder (or wherever config parameter `models_dir` + specifies). + +- Downloading a model from an arbitrary URL and installing it in + `models_dir`. + +- Special handling for Civitai model URLs which allow the user to + paste in a model page's URL or download link. Any metadata provided + by Civitai, such as trigger terms, are captured and placed in the + model config record. 
+
+- Special handling for HuggingFace repo_ids to recursively download
+  the contents of the repository, paying attention to alternative
+  variants such as fp16.
+
+- Probing of models to determine their type, base type and other key
+  information.
+
+- Interface with the InvokeAI event bus to provide status updates on
+  the download, installation and registration process.
+
+### Initializing the installer
+
+A default installer is created at InvokeAI API startup time and stored
+in `ApiDependencies.invoker.services.model_install_service`. It can
+also be retrieved from an invocation's `context` argument with
+`context.services.model_install_service`.
+
+In the event you wish to create a new installer, you may use the
+following initialization pattern:
+
+```
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.download_manager import DownloadQueueService
+from invokeai.app.services.model_record_service import ModelRecordServiceBase
+from invokeai.app.services.model_install_service import ModelInstallService
+
+config = InvokeAIAppConfig.get_config()
+queue = DownloadQueueService()
+store = ModelRecordServiceBase.open(config)
+installer = ModelInstallService(config=config, queue=queue, store=store)
+```
+
+The full form of `ModelInstallService()` takes the following
+parameters. Each parameter defaults to a reasonable value, but it is
+recommended that you set them explicitly as shown in the above example.
+
+| **Argument** | **Type** | **Default** | **Description** |
+|------------------|------------------------------|-------------|-------------------------------------------|
+| `config` | InvokeAIAppConfig | Use system-wide config | InvokeAI app configuration object |
+| `queue` | DownloadQueueServiceBase | Create a new download queue for internal use | Download queue |
+| `store` | ModelRecordServiceBase | Use config to select the database to open | Config storage database |
+| `event_bus` | EventServiceBase | None | An event bus to send download/install progress events to |
+| `event_handlers` | List[DownloadEventHandler] | None | Event handlers for the download queue |
+
+Note that if `store` is not provided, the class will use
+`ModelRecordServiceBase.open(config)` to select the database to use.
+
+Once initialized, the installer provides the following methods:
+
+#### install_job = installer.install_model()
+
+The `install_model()` method is the core of the installer. The
+following illustrates basic usage:
+
+```
+from pathlib import Path
+
+sources = [
+    Path('/opt/models/sushi.safetensors'),                   # a local safetensors file
+    Path('/opt/models/sushi_diffusers/'),                    # a local diffusers folder
+    'runwayml/stable-diffusion-v1-5',                        # a repo_id
+    'runwayml/stable-diffusion-v1-5:vae',                    # a subfolder within a repo_id
+    'https://civitai.com/api/download/models/63006',         # a civitai direct download link
+    'https://civitai.com/models/8765?modelVersionId=10638',  # civitai model page
+    'https://s3.amazon.com/fjacks/sd-3.safetensors',         # arbitrary URL
+]
+
+for source in sources:
+    install_job = installer.install_model(source)
+
+source2key = installer.wait_for_installs()
+for source in sources:
+    model_key = source2key[source]
+    print(f"{source} installed as {model_key}")
+```
+
+As shown here, the `install_model()` method accepts a variety of
+sources, including local safetensors files, local diffusers folders,
+HuggingFace repo_ids with and without a subfolder designation,
+Civitai model URLs, and arbitrary URLs that point to checkpoint files
+(but not to folders).
+
+Each call to `install_model()` returns a `ModelInstallJob`, a
+subclass of `DownloadJobBase`.
The install job has additional +install-specific fields described in the next section. + +Each install job will run in a series of background threads using +the object's download queue. You may block until all install jobs are +completed (or errored) by calling the `wait_for_installs()` method as +shown in the code example. `wait_for_installs()` will return a `dict` +that maps the requested source to the key of the installed model. In +the case that a model fails to download or install, its value in the +dict will be None. The actual cause of the error will be reported in +the corresponding job's `error` field. + +Alternatively you may install event handlers and/or listen for events +on the InvokeAI event bus in order to monitor the progress of the +requested installs. + +The full list of arguments to `model_install()` is as follows: + +| **Argument** | **Type** | **Default** | **Description** | +|------------------|------------------------------|-------------|-------------------------------------------| +| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id | +| `inplace` | bool | True | Leave a local model in its current location | +| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) | +| `subfolder` | str | None | Repository subfolder (HuggingFace only) | +| `probe_override` | Dict[str, Any] | None | Override all or a portion of model's probed attributes | +| `metadata` | ModelSourceMetadata | None | Provide metadata that will be added to model's config | +| `access_token` | str | None | Provide authorization information needed to download | +| `priority` | int | 10 | Download queue priority for the job | + + +The `inplace` field controls how local model Paths are handled. If +True (the default), then the model is simply registered in its current +location by the installer's `ModelConfigRecordService`. Otherwise, the +model will be moved into the location specified by the `models_dir` +application configuration parameter. + +The `variant` field is used for HuggingFace repo_ids only. If +provided, the repo_id download handler will look for and download +tensors files that follow the convention for the selected variant: + +- "fp16" will select files named "*model.fp16.{safetensors,bin}" +- "onnx" will select files ending with the suffix ".onnx" +- "openvino" will select files beginning with "openvino_model" + +In the special case of the "fp16" variant, the installer will select +the 32-bit version of the files if the 16-bit version is unavailable. + +`subfolder` is used for HuggingFace repo_ids only. If provided, the +model will be downloaded from the designated subfolder rather than the +top-level repository folder. If a subfolder is attached to the repo_id +using the format `repo_owner/repo_name:subfolder`, then the subfolder +specified by the repo_id will override the subfolder argument. + +`probe_override` can be used to override all or a portion of the +attributes returned by the model prober. This can be used to overcome +cases in which automatic probing is unable to (correctly) determine +the model's attribute. The most common situation is the +`prediction_type` field for sd-2 (and rare sd-1) models. 
Here is an example of how it works:
+
+```
+install_job = installer.install_model(
+    source='stabilityai/stable-diffusion-2-1',
+    variant='fp16',
+    probe_override=dict(
+        prediction_type=SchedulerPredictionType('v_prediction')
+    )
+)
+```
+
+`metadata` allows you to attach custom metadata to the installed
+model. See the next section for details.
+
+`priority` and `access_token` are passed to the download queue and
+have the same effect as they do for the `DownloadQueueServiceBase`.
+
+#### Monitoring the install job process
+
+When you create an install job with `model_install()`, events will be
+passed to the list of `DownloadEventHandlers` provided at installer
+initialization time. Event handlers can also be added to individual
+model install jobs by calling their `add_handler()` method as
+described earlier for the `DownloadQueueService`.
+
+If the `event_bus` argument was provided, events will also be
+broadcast to the InvokeAI event bus. The events will appear on the bus
+as a singular event type named `model_event` with a payload of
+`job`. You can then retrieve the job and check its status.
+
+** TO DO: ** consider breaking `model_event` into
+`model_install_started`, `model_install_completed`, etc. The event bus
+features have not yet been tested with FastAPI/websockets, and it may
+turn out that the job object is not serializable.
+
+#### Model metadata and probing
+
+The install service has special handling for HuggingFace and Civitai
+URLs that captures metadata from the source and includes it in the
+model configuration record. For example, fetching the Civitai model
+8765 will produce a config record similar to this (in YAML
+representation):
+
+```
+5abc3ef8600b6c1cc058480eaae3091e:
+  path: sd-1/lora/to8contrast-1-5.safetensors
+  name: to8contrast-1-5
+  base_model: sd-1
+  model_type: lora
+  model_format: lycoris
+  key: 5abc3ef8600b6c1cc058480eaae3091e
+  hash: 5abc3ef8600b6c1cc058480eaae3091e
+  description: 'Trigger terms: to8contrast style'
+  author: theovercomer8
+  license: allowCommercialUse=Sell; allowDerivatives=True; allowNoCredit=True
+  source: https://civitai.com/models/8765?modelVersionId=10638
+  thumbnail_url: null
+  tags:
+  - model
+  - style
+  - portraits
+```
+
+For sources that do not provide model metadata, you can attach custom
+fields by providing a `metadata` argument to `model_install()` using
+an initialized `ModelSourceMetadata` object (available for import from
+`model_install_service.py`):
+
+```
+from invokeai.app.services.model_install_service import ModelSourceMetadata
+
+meta = ModelSourceMetadata(
+    name="my model",
+    author="Sushi Chef",
+    description="Highly customized model; trigger with 'sushi'",
+    license="mit",
+    thumbnail_url="http://s3.amazon.com/ljack/pics/sushi.png",
+    tags=["sfw", "food"],
+)
+install_job = installer.install_model(
+    source='sushi_chef/model3',
+    variant='fp16',
+    metadata=meta,
+)
+```
+
+It is not currently recommended to provide custom metadata when
+installing from Civitai or HuggingFace sources, as the metadata
+provided by the source will overwrite the fields you provide. Instead,
+after the model is installed, you can use
+`ModelRecordService.update_model()` to change the desired fields.
+
+** TO DO: ** Change the logic so that the caller's metadata fields take
+precedence over those provided by the source.
+
+#### Other installer methods
+
+This section describes additional, less-frequently-used attributes and
+methods provided by the installer class.
+ +##### installer.wait_for_installs() + +This is equivalent to the `DownloadQueue` `join()` method. It will +block until all the active jobs in the install queue have reached a +terminal state (completed, errored or cancelled). + +##### installer.queue, installer.store, installer.config + +These attributes provide access to the `DownloadQueueServiceBase`, +`ModelConfigRecordServiceBase`, and `InvokeAIAppConfig` objects that +the installer uses. + +For example, to temporarily pause all pending installations, you can +do this: + +``` +installer.queue.pause_all_jobs() +``` +##### key = installer.register_path(model_path, overrides), key = installer.install_path(model_path, overrides) + +These methods bypass the download queue and directly register or +install the model at the indicated path, returning the unique ID for +the installed model. + +Both methods accept a Path object corresponding to a checkpoint or +diffusers folder, and an optional dict of attributes to use to +override the values derived from model probing. + +The difference between `register_path()` and `install_path()` is that +the former will not move the model from its current position, while +the latter will move it into the `models_dir` hierarchy. + +##### installer.unregister(key) + +This will remove the model config record for the model at key, and is +equivalent to `installer.store.unregister(key)` + +##### installer.delete(key) + +This is similar to `unregister()` but has the additional effect of +deleting the underlying model file(s) -- even if they were outside the +`models_dir` directory! + +##### installer.conditionally_delete(key) + +This method will call `unregister()` if the model identified by `key` +is outside the `models_dir` hierarchy, and call `delete()` if the +model is inside. + +#### List[str]=installer.scan_directory(scan_dir: Path, install: bool) + +This method will recursively scan the directory indicated in +`scan_dir` for new models and either install them in the models +directory or register them in place, depending on the setting of +`install` (default False). + +The return value is the list of keys of the new installed/registered +models. + +#### installer.scan_models_directory() + +This method scans the models directory for new models and registers +them in place. Models that are present in the +`ModelConfigRecordService` database whose paths are not found will be +unregistered. + +#### installer.sync_to_config() + +This method synchronizes models in the models directory and autoimport +directory to those in the `ModelConfigRecordService` database. New +models are registered and orphan models are unregistered. + +#### hash=installer.hash(model_path) + +This method is calls the fasthash algorithm on a model's Path +(either a file or a folder) to generate a unique ID based on the +contents of the model. + +##### installer.start(invoker) + +The `start` method is called by the API intialization routines when +the API starts up. Its effect is to call `sync_to_config()` to +synchronize the model record store database with what's currently on +disk. + +This method should not ordinarily be called manually. 
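+
+## Putting it all together
+
+The services described above are designed to be composed. The
+following is a rough sketch only (the model source is illustrative,
+and the exact keyword arguments should be checked against the classes
+themselves), but it shows how a standalone script might wire the
+record store, download queue, installer and loader together:
+
+```
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.download_manager import DownloadQueueService
+from invokeai.app.services.model_record_service import ModelRecordServiceBase
+from invokeai.app.services.model_install_service import ModelInstallService
+from invokeai.app.services.model_loader_service import ModelLoadService, SubModelType
+
+config = InvokeAIAppConfig.get_config()
+store = ModelRecordServiceBase.open(config)
+queue = DownloadQueueService()
+installer = ModelInstallService(config=config, queue=queue, store=store)
+loader = ModelLoadService(config, store)
+
+# Install a model and block until the background job finishes.
+source = 'https://civitai.com/api/download/models/128713'
+installer.install_model(source)
+key = installer.wait_for_installs()[source]
+
+# Load a submodel and use it while it is locked in VRAM.
+model_info = loader.get_model(key, SubModelType('vae'))
+with model_info as vae:
+    pass  # run inference with the loaded submodel here
+```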
diff --git a/docs/contributing/contribution_guides/development.md b/docs/contributing/contribution_guides/development.md index 086fd6e90d0..6f7f6cc8b6b 100644 --- a/docs/contributing/contribution_guides/development.md +++ b/docs/contributing/contribution_guides/development.md @@ -14,6 +14,7 @@ Once you're setup, for more information, you can review the documentation specif * #### [InvokeAI Architecure](../ARCHITECTURE.md) * #### [Frontend Documentation](./contributingToFrontend.md) * #### [Node Documentation](../INVOCATIONS.md) +* #### [InvokeAI Model Manager](../MODEL_MANAGER.md) * #### [Local Development](../LOCAL_DEVELOPMENT.md) diff --git a/docs/features/CONFIGURATION.md b/docs/features/CONFIGURATION.md index cfd65f8a613..d249e5331e2 100644 --- a/docs/features/CONFIGURATION.md +++ b/docs/features/CONFIGURATION.md @@ -207,11 +207,8 @@ if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is | Setting | Default Value | Description | |----------|----------------|--------------| -| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory | -| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory | -| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory | -| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory | -| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file | +| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory (not recommended)| +| `model_config_db` | `auto` | Location of the model configuration database. Specify `auto` to use the main invokeai.db database, or specify a `.yaml` or `.db` file to store the data externally.| | `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager | | `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models | | `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database | @@ -234,6 +231,18 @@ Paths: # controlnet_dir: null ``` +### Model Cache + +These options control the size of various caches that InvokeAI uses +during the model loading and conversion process. All units are in GB + +| Setting | Default Value | Description | +|----------|----------------|--------------| +| `disk` | `20.0` | Before loading a model into memory, InvokeAI converts .ckpt and .safetensors models into diffusers format and saves them to disk. This option controls the maximum size of the directory in which these converted models are stored. If set to zero, then only the most recently-used model will be cached. | +| `ram` | `6.0` | After loading a model from disk, it is kept in system RAM until it is needed again. This option controls how much RAM is set aside for this purpose. Larger amounts allow more models to reside in RAM and for InvokeAI to quickly switch between them. | +| `vram` | `0.25` | This allows smaller models to remain in VRAM, speeding up execution modestly. It should be a small number. 
| + + ### Logging These settings control the information, warning, and debugging diff --git a/docs/installation/050_INSTALLING_MODELS.md b/docs/installation/050_INSTALLING_MODELS.md index d455d2146f3..e2a500512cb 100644 --- a/docs/installation/050_INSTALLING_MODELS.md +++ b/docs/installation/050_INSTALLING_MODELS.md @@ -123,11 +123,20 @@ installation. Examples: # (list all controlnet models) invokeai-model-install --list controlnet -# (install the model at the indicated URL) +# (install the diffusers model using its hugging face repo_id) +invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0 + +# (install a diffusers model that lives in a subfolder) +invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0:vae + +# (install the checkpoint model at the indicated URL) invokeai-model-install --add https://civitai.com/api/download/models/128713 -# (delete the named model) -invokeai-model-install --delete sd-1/main/analog-diffusion +# (delete the named model if its name is unique) +invokeai-model-install --delete analog-diffusion + +# (delete the named model using its fully qualified name) +invokeai-model-install --delete sd-1/main/test_model ``` ### Installation via the Web GUI @@ -141,6 +150,24 @@ left-hand panel) and navigate to *Import Models* wish to install. You may use a URL, HuggingFace repo id, or a path on your local disk. +There is special scanning for CivitAI URLs which lets +you cut-and-paste either the URL for a CivitAI model page +(e.g. https://civitai.com/models/12345), or the direct download link +for a model (e.g. https://civitai.com/api/download/models/12345). + +If the desired model is a HuggingFace diffusers model that is located +in a subfolder of the repository (e.g. vae), then append the subfolder +to the end of the repo_id like this: + +``` +# a VAE model located in subfolder "vae" +stabilityai/stable-diffusion-xl-base-1.0:vae + +# version 2 of the model located in subfolder "v2" +monster-labs/control_v1p_sd15_qrcode_monster:v2 + +``` + 3. Alternatively, the *Scan for Models* button allows you to paste in the path to a folder somewhere on your machine. It will be scanned for importable models and prompt you to add the ones of your choice. 
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9db35fb5c3d..a05d6d0d347 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -19,6 +19,7 @@ from invokeai.version.invokeai_version import __version__ from ..services.default_graphs import create_system_graphs +from ..services.download_manager import DownloadQueueService from ..services.graph import GraphExecutionState, LibraryGraph from ..services.image_file_storage import DiskImageFileStorage from ..services.invocation_queue import MemoryInvocationQueue @@ -26,7 +27,9 @@ from ..services.invocation_stats import InvocationStatsService from ..services.invoker import Invoker from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage -from ..services.model_manager_service import ModelManagerService +from ..services.model_install_service import ModelInstallService +from ..services.model_loader_service import ModelLoadService +from ..services.model_record_service import ModelRecordServiceBase from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage from ..services.thread import lock @@ -127,8 +130,12 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ) ) + download_queue = DownloadQueueService(event_bus=events) + model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=lock) + model_loader = ModelLoadService(config, model_record_store) + model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events) + services = InvocationServices( - model_manager=ModelManagerService(config, logger), events=events, latents=latents, images=images, @@ -141,6 +148,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger configuration=config, performance_statistics=InvocationStatsService(graph_execution_manager), logger=logger, + download_queue=download_queue, + model_record_store=model_record_store, + model_loader=model_loader, + model_installer=model_installer, session_queue=SqliteSessionQueue(conn=db_conn, lock=lock), session_processor=DefaultSessionProcessor(), invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size), diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index ebc40f5ce52..12ffc7e916a 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,35 +2,60 @@ import pathlib -from typing import List, Literal, Optional, Union +from enum import Enum +from typing import Any, List, Literal, Optional, Union from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter from pydantic import BaseModel, parse_obj_as from starlette.exceptions import HTTPException +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.download_manager import DownloadJobRemoteSource, DownloadJobStatus, UnknownJobIDException +from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert +from invokeai.app.services.model_install_service import ModelInstallJob from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import MergeInterpolationMethod -from invokeai.backend.model_management.models import ( +from invokeai.backend.model_manager import ( OPENAPI_MODEL_CONFIGS, + DuplicateModelException, InvalidModelException, - ModelNotFoundException, + ModelConfigBase, + ModelSearch, SchedulerPredictionType, + 
UnknownModelException, ) -from ..dependencies import ApiDependencies - models_router = APIRouter(prefix="/v1/models", tags=["models"]) -UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] +# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config +# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker +# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig` +# This is the reason for the calls to dict() followed by pydantic.parse_obj_as() + +# There are still numerous mypy errors here because it does not seem to like this +# way of dynamically generating the typing hints below. +InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): - models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] + models: List[InvokeAIModelConfig] + + +class ModelDownloadStatus(BaseModel): + """Return information about a background installation job.""" + + job_id: int + source: str + priority: int + bytes: int + total_bytes: int + status: DownloadJobStatus + + +class JobControlOperation(str, Enum): + START = "Start" + PAUSE = "Pause" + CANCEL = "Cancel" @models_router.get( @@ -42,19 +67,22 @@ async def list_models( base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: - """Gets a list of models""" + """Get a list of models.""" + record_store = ApiDependencies.invoker.services.model_record_store if base_models and len(base_models) > 0: models_raw = list() for base_model in base_models: - models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) + models_raw.extend( + [x.dict() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)] + ) else: - models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) + models_raw = [x.dict() for x in record_store.search_by_name(model_type=model_type)] models = parse_obj_as(ModelsList, {"models": models_raw}) return models @models_router.patch( - "/{base_model}/{model_type}/{model_name}", + "/i/{key}", operation_id="update_model", responses={ 200: {"description": "The model was updated successfully"}, @@ -63,69 +91,36 @@ async def list_models( 409: {"description": "There is already a model corresponding to the new name"}, }, status_code=200, - response_model=UpdateModelResponse, + response_model=InvokeAIModelConfig, ) async def update_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> UpdateModelResponse: + key: str = Path(description="Unique key of model"), + info: InvokeAIModelConfig = Body(description="Model configuration"), +) -> InvokeAIModelConfig: """Update model contents with a new config. 
If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger - + info_dict = info.dict() + record_store = ApiDependencies.invoker.services.model_record_store + model_install = ApiDependencies.invoker.services.model_installer try: - previous_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - - # rename operation requested - if info.model_name != model_name or info.base_model != base_model: - ApiDependencies.invoker.services.model_manager.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=info.model_name, - new_base=info.base_model, - ) - logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}") - # update information to support an update of attributes - model_name = info.model_name - base_model = info.base_model - new_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - if new_info.get("path") != previous_info.get( - "path" - ): # model manager moved model path during rename - don't overwrite it - info.path = new_info.get("path") - - # replace empty string values with None/null to avoid phenomenon of vae: '' - info_dict = info.dict() - info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} - - ApiDependencies.invoker.services.model_manager.update_model( - model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict - ) - - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - model_response = parse_obj_as(UpdateModelResponse, model_raw) - except ModelNotFoundException as e: + new_config = record_store.update_model(key, config=info_dict) + except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - except Exception as e: + + try: + # In the event that the model's name, type or base has changed, and the model itself + # resides in the invokeai root models directory, then the next statement will move + # the model file into its new canonical location. + new_config = model_install.sync_model_path(new_config.key) + model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict()) + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=409, detail=str(e)) return model_response @@ -141,7 +136,7 @@ async def update_model( 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, - response_model=ImportModelResponse, + response_model=ModelDownloadStatus, ) async def import_model( location: str = Body(description="A model path, repo_id or URL to import"), @@ -149,30 +144,47 @@ async def import_model( description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints", default=None, ), -) -> ImportModelResponse: - """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" + priority: Optional[int] = Body( + description="Which import jobs run first. 
Lower values run before higher ones.", + default=10, + ), +) -> ModelDownloadStatus: + """ + Add a model using its local path, repo_id, or remote URL. + + Models will be downloaded, probed, configured and installed in a + series of background threads. The return object has a `job_id` property + that can be used to control the download job. + + The priority controls which import jobs run first. Lower values run before + higher ones. + + The prediction_type applies to SDv2 models only and can be one of + "v_prediction", "epsilon", or "sample". Default if not provided is + "v_prediction". - items_to_import = {location} - prediction_types = {x.value: x for x in SchedulerPredictionType} + Listen on the event bus for a series of `model_event` events with an `id` + matching the returned job id to get the progress, completion status, errors, + and information on the model that was installed. + """ logger = ApiDependencies.invoker.services.logger try: - installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type) + installer = ApiDependencies.invoker.services.model_installer + result = installer.install_model( + location, + probe_override={"prediction_type": SchedulerPredictionType(prediction_type) if prediction_type else None}, + priority=priority, ) - info = installed_models.get(location) - - if not info: - logger.error("Import failed") - raise HTTPException(status_code=415) - - logger.info(f"Successfully imported {location}, got {info}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.name, base_model=info.base_model, model_type=info.model_type + return ModelDownloadStatus( + job_id=result.id, + source=result.source, + priority=result.priority, + bytes=result.bytes, + total_bytes=result.total_bytes, + status=result.status, ) - return parse_obj_as(ImportModelResponse, model_raw) - - except ModelNotFoundException as e: + except UnknownModelException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) except InvalidModelException as e: @@ -189,29 +201,40 @@ async def import_model( responses={ 201: {"description": "The model added successfully"}, 404: {"description": "The model could not be found"}, - 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, 409: {"description": "There is already a model corresponding to this path or repo_id"}, + 415: {"description": "Unrecognized file/folder format"}, }, status_code=201, - response_model=ImportModelResponse, + response_model=InvokeAIModelConfig, ) async def add_model( - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> ImportModelResponse: - """Add a model using the configuration information appropriate for its type. Only local models can be added by path""" + info: InvokeAIModelConfig = Body(description="Model configuration"), +) -> InvokeAIModelConfig: + """ + Add a model using the configuration information appropriate for its type. Only local models can be added by path. + This call will block until the model is installed. 
+ """ logger = ApiDependencies.invoker.services.logger + path = info.path + installer = ApiDependencies.invoker.services.model_installer + record_store = ApiDependencies.invoker.services.model_record_store + try: + key = installer.install_path(path) + logger.info(f"Created model {key} for {path}") + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + # update with the provided info try: - ApiDependencies.invoker.services.model_manager.add_model( - info.model_name, info.base_model, info.model_type, model_attributes=info.dict() - ) - logger.info(f"Successfully added {info.model_name}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.model_name, base_model=info.base_model, model_type=info.model_type - ) - return parse_obj_as(ImportModelResponse, model_raw) - except ModelNotFoundException as e: + info_dict = info.dict() + new_config = record_store.update_model(key, new_config=info_dict) + return parse_obj_as(InvokeAIModelConfig, new_config.dict()) + except UnknownModelException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: @@ -220,33 +243,34 @@ async def add_model( @models_router.delete( - "/{base_model}/{model_type}/{model_name}", + "/i/{key}", operation_id="del_model", responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}}, status_code=204, response_model=None, ) async def delete_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), + key: str = Path(description="Unique key of model to remove from model registry."), + delete_files: Optional[bool] = Query(description="Delete underlying files and directories as well.", default=False), ) -> Response: """Delete Model""" logger = ApiDependencies.invoker.services.logger try: - ApiDependencies.invoker.services.model_manager.del_model( - model_name, base_model=base_model, model_type=model_type - ) - logger.info(f"Deleted model: {model_name}") + installer = ApiDependencies.invoker.services.model_installer + if delete_files: + installer.delete(key) + else: + installer.unregister(key) + logger.info(f"Deleted model: {key}") return Response(status_code=204) - except ModelNotFoundException as e: + except UnknownModelException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) @models_router.put( - "/convert/{base_model}/{model_type}/{model_name}", + "/convert/{key}", operation_id="convert_model", responses={ 200: {"description": "Model converted successfully"}, @@ -254,33 +278,26 @@ async def delete_model( 404: {"description": "Model not found"}, }, status_code=200, - response_model=ConvertModelResponse, + response_model=InvokeAIModelConfig, ) async def convert_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), + key: str = Path(description="Unique key of model to convert from checkpoint/safetensors to diffusers format."), convert_dest_directory: Optional[str] = Query( default=None, description="Save the converted model to the designated directory" ), -) -> ConvertModelResponse: +) -> InvokeAIModelConfig: """Convert a checkpoint model into a diffusers model, 
optionally saving to the indicated destination directory, or `models` if none.""" - logger = ApiDependencies.invoker.services.logger try: - logger.info(f"Converting model: {model_name}") dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None - ApiDependencies.invoker.services.model_manager.convert_model( - model_name, - base_model=base_model, - model_type=model_type, - convert_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name, base_model=base_model, model_type=model_type + converter = ModelConvert( + loader=ApiDependencies.invoker.services.model_loader, + installer=ApiDependencies.invoker.services.model_installer, + store=ApiDependencies.invoker.services.model_record_store, ) - response = parse_obj_as(ConvertModelResponse, model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") + model_config = converter.convert_model(key, dest_directory=dest) + response = parse_obj_as(InvokeAIModelConfig, model_config.dict()) + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response @@ -299,11 +316,12 @@ async def convert_model( async def search_for_models( search_path: pathlib.Path = Query(description="Directory path to search for models"), ) -> List[pathlib.Path]: + """Search for all models in a server-local path.""" if not search_path.is_dir(): raise HTTPException( status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory" ) - return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) + return ModelSearch().search(search_path) @models_router.get( @@ -317,7 +335,10 @@ async def search_for_models( ) async def list_ckpt_configs() -> List[pathlib.Path]: """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" - return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() + config = ApiDependencies.invoker.services.configuration + conf_path = config.legacy_conf_path + root_path = config.root_path + return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] @models_router.post( @@ -330,27 +351,32 @@ async def list_ckpt_configs() -> List[pathlib.Path]: response_model=bool, ) async def sync_to_config() -> bool: - """Call after making changes to models.yaml, autoimport directories or models directory to synchronize - in-memory data structures with disk data structures.""" - ApiDependencies.invoker.services.model_manager.sync_to_config() + """ + Synchronize model in-memory data structures with disk. + + Call after making changes to models.yaml, autoimport directories + or models directory. 
+ """ + installer = ApiDependencies.invoker.services.model_installer + installer.sync_to_config() return True @models_router.put( - "/merge/{base_model}", + "/merge", operation_id="merge_models", responses={ 200: {"description": "Model converted successfully"}, 400: {"description": "Incompatible models"}, 404: {"description": "One or more models not found"}, + 409: {"description": "An identical merged model is already installed"}, }, status_code=200, - response_model=MergeModelResponse, + response_model=InvokeAIModelConfig, ) async def merge_models( - base_model: BaseModelType = Path(description="Base model"), - model_names: List[str] = Body(description="model name", min_items=2, max_items=3), - merged_model_name: Optional[str] = Body(description="Name of destination model"), + keys: List[str] = Body(description="model name", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), force: Optional[bool] = Body( @@ -360,29 +386,147 @@ async def merge_models( description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None, ), -) -> MergeModelResponse: - """Convert a checkpoint model into a diffusers model""" +) -> InvokeAIModelConfig: + """Merge the indicated diffusers model.""" logger = ApiDependencies.invoker.services.logger try: - logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") + logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models( - model_names, - base_model, - merged_model_name=merged_model_name or "+".join(model_names), + converter = ModelConvert( + loader=ApiDependencies.invoker.services.model_loader, + installer=ApiDependencies.invoker.services.model_installer, + store=ApiDependencies.invoker.services.model_record_store, + ) + result: ModelConfigBase = converter.merge_models( + model_keys=keys, + merged_model_name=merged_model_name, alpha=alpha, interp=interp, force=force, merge_dest_directory=dest, ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - result.name, - base_model=base_model, - model_type=ModelType.Main, - ) - response = parse_obj_as(ConvertModelResponse, model_raw) - except ModelNotFoundException: - raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") + response = parse_obj_as(InvokeAIModelConfig, result.dict()) + except DuplicateModelException as e: + raise HTTPException(status_code=409, detail=str(e)) + except UnknownModelException: + raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response + + +@models_router.get( + "/jobs", + operation_id="list_install_jobs", + responses={ + 200: {"description": "The control job was updated successfully"}, + 400: {"description": "Bad request"}, + }, + status_code=200, + response_model=List[ModelDownloadStatus], +) +async def list_install_jobs() -> List[ModelDownloadStatus]: + """List active and pending model installation jobs.""" + job_mgr = ApiDependencies.invoker.services.download_queue + 
jobs = job_mgr.list_jobs() + return [ + ModelDownloadStatus( + job_id=x.id, + source=x.source, + priority=x.priority, + bytes=x.bytes, + total_bytes=x.total_bytes, + status=x.status, + ) + for x in jobs + if isinstance(x, ModelInstallJob) + ] + + +@models_router.patch( + "/jobs/control/{operation}/{job_id}", + operation_id="control_download_jobs", + responses={ + 200: {"description": "The control job was updated successfully"}, + 400: {"description": "Bad request"}, + 404: {"description": "The job could not be found"}, + }, + status_code=200, + response_model=ModelDownloadStatus, +) +async def control_download_jobs( + job_id: int = Path(description="Download/install job_id for start, pause and cancel operations"), + operation: JobControlOperation = Path(description="The operation to perform on the job."), + priority_delta: Optional[int] = Body( + description="Change in job priority for priority operations only. Negative numbers increase priority.", + default=None, + ), +) -> ModelDownloadStatus: + """Start, pause, cancel, or change the run priority of a running model install job.""" + logger = ApiDependencies.invoker.services.logger + job_mgr = ApiDependencies.invoker.services.download_queue + try: + job = job_mgr.id_to_job(job_id) + + if operation == JobControlOperation.START: + job_mgr.start_job(job_id) + + elif operation == JobControlOperation.PAUSE: + job_mgr.pause_job(job_id) + + elif operation == JobControlOperation.CANCEL: + job_mgr.cancel_job(job_id) + + else: + raise ValueError(f"unknown operation {operation}") + bytes = 0 + total_bytes = 0 + if isinstance(job, DownloadJobRemoteSource): + bytes = job.bytes + total_bytes = job.total_bytes + + return ModelDownloadStatus( + job_id=job_id, + source=job.source, + priority=job.priority, + status=job.status, + bytes=bytes, + total_bytes=total_bytes, + ) + except UnknownJobIDException as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + + +@models_router.patch( + "/jobs/cancel_all", + operation_id="cancel_all_download_jobs", + responses={ + 204: {"description": "All jobs cancelled successfully"}, + 400: {"description": "Bad request"}, + }, +) +async def cancel_all_download_jobs(): + """Cancel all model installation jobs.""" + logger = ApiDependencies.invoker.services.logger + job_mgr = ApiDependencies.invoker.services.download_queue + logger.info("Cancelling all download jobs.") + job_mgr.cancel_all_jobs() + return Response(status_code=204) + + +@models_router.patch( + "/jobs/prune", + operation_id="prune_jobs", + responses={ + 204: {"description": "All completed jobs have been pruned"}, + 400: {"description": "Bad request"}, + }, +) +async def prune_jobs(): + """Prune all completed and errored jobs.""" + mgr = ApiDependencies.invoker.services.download_queue + mgr.prune_jobs() + return Response(status_code=204) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index fdbd64b30dd..709f1a3cf81 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -151,7 +151,7 @@ def custom_openapi(): invoker_schema["output"] = outputs_ref invoker_schema["class"] = "invocation" - from invokeai.backend.model_management.models import get_model_config_enums + from invokeai.backend.model_manager.models import get_model_config_enums for model_config_format_enum in set(get_model_config_enums()): name = model_config_format_enum.__qualname__ @@ -201,6 +201,10 @@ def overridden_redoc(): def invoke_api(): + if app_config.version:
+ print(f"InvokeAI version {__version__}") + return + def find_port(port: int): """Find a port not in use starting at given port""" # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! @@ -252,7 +256,4 @@ def find_port(port: int): if __name__ == "__main__": - if app_config.version: - print(f"InvokeAI version {__version__}") - else: - invoke_api() + invoke_api() diff --git a/invokeai/app/cli/completer.py b/invokeai/app/cli/completer.py index 5aece8a0584..6cd0db5fdb8 100644 --- a/invokeai/app/cli/completer.py +++ b/invokeai/app/cli/completer.py @@ -10,10 +10,11 @@ from typing import Dict, List, Literal, get_args, get_origin, get_type_hints import invokeai.backend.util.logging as logger +from invokeai.backend.model_manager import ModelType -from ...backend import ModelManager from ..invocations.baseinvocation import BaseInvocation from ..services.invocation_services import InvocationServices +from ..services.model_record_service import ModelRecordServiceBase from .commands import BaseCommand # singleton object, class variable @@ -21,11 +22,11 @@ class Completer(object): - def __init__(self, model_manager: ModelManager): + def __init__(self, model_record_store: ModelRecordServiceBase): self.commands = self.get_commands() self.matches = None self.linebuffer = None - self.manager = model_manager + self.store = model_record_store return def complete(self, text, state): @@ -127,7 +128,7 @@ def get_parameter_options(self, parameter: str, typehint) -> List[str]: if get_origin(typehint) == Literal: return get_args(typehint) if parameter == "model": - return self.manager.model_names() + return [x.name for x in self.store.model_info_by_name(model_type=ModelType.Main)] def _pre_input_hook(self): if self.linebuffer: @@ -142,7 +143,7 @@ def set_autocompleter(services: InvocationServices) -> Completer: if completer: return completer - completer = Completer(services.model_manager) + completer = Completer(services.model_record_store) readline.set_completer(completer.complete) try: diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 2f8a4d2cbd7..5e203eefc00 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -30,6 +30,8 @@ from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.resource_name import SimpleNameService + from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor + from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue from invokeai.app.services.urls import LocalUrlService from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -38,6 +40,7 @@ from .cli.completer import set_autocompleter from .invocations.baseinvocation import BaseInvocation from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id + from .services.download_manager import DownloadQueueService from .services.events import EventServiceBase from .services.graph import ( Edge, @@ -52,9 +55,12 @@ from .services.invocation_services import InvocationServices from .services.invoker import Invoker from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage - from .services.model_manager_service import ModelManagerService + from .services.model_install_service import ModelInstallService + from .services.model_loader_service import ModelLoadService + from 
.services.model_record_service import ModelRecordServiceBase from .services.processor import DefaultInvocationProcessor from .services.sqlite import SqliteItemStorage + from .services.thread import lock if torch.backends.mps.is_available(): import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) @@ -228,7 +234,12 @@ def invoke_all(context: CliContext): def invoke_cli(): + if config.version: + print(f"InvokeAI version {__version__}") + return + logger.info(f"InvokeAI version {__version__}") + # get the optional list of invocations to execute on the command line parser = config.get_parser() parser.add_argument("commands", nargs="*") @@ -239,8 +250,6 @@ def invoke_cli(): if infile := config.from_file: sys.stdin = open(infile, "r") - model_manager = ModelManagerService(config, logger) - events = EventServiceBase() output_folder = config.output_path @@ -254,15 +263,22 @@ def invoke_cli(): db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution logger.info(f'InvokeAI database location is "{db_location}"') - graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions") + download_queue = DownloadQueueService(event_bus=events) + model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=None) + model_loader = ModelLoadService(config, model_record_store) + model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events) + + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + conn=db_conn, table_name="graph_executions", lock=lock + ) urls = LocalUrlService() - image_record_storage = SqliteImageRecordStorage(conn=db_conn) + image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") names = SimpleNameService() - board_record_storage = SqliteBoardRecordStorage(conn=db_conn) - board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn) + board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock) + board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock) boards = BoardService( services=BoardServiceDependencies( @@ -297,20 +313,25 @@ def invoke_cli(): ) services = InvocationServices( - model_manager=model_manager, events=events, latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")), images=images, boards=boards, board_images=board_images, queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"), + graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), performance_statistics=InvocationStatsService(graph_execution_manager), logger=logger, + download_queue=download_queue, + model_record_store=model_record_store, + model_loader=model_loader, + model_installer=model_installer, configuration=config, invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size), + session_queue=SqliteSessionQueue(conn=db_conn, lock=lock), + session_processor=DefaultSessionProcessor(), ) system_graphs = create_system_graphs(services.graph_library) @@ -478,7 +499,4 @@ def invoke_cli(): if __name__ == "__main__": - if config.version: - print(f"InvokeAI version {__version__}") - else: - invoke_cli() + invoke_cli() diff --git a/invokeai/app/invocations/compel.py 
b/invokeai/app/invocations/compel.py index b2634c2c56e..360094fe740 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -13,8 +13,8 @@ SDXLConditioningInfo, ) -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import ModelNotFoundException, ModelType +from ...backend.model_manager import ModelType, UnknownModelException +from ...backend.model_manager.lora import ModelPatcher from ...backend.util.devices import torch_dtype from .baseinvocation import ( BaseInvocation, @@ -60,23 +60,23 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_loader.get_model( **self.clip.tokenizer.dict(), context=context, ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_loader.get_model( **self.clip.text_encoder.dict(), context=context, ) def _lora_loader(): for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context) + lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): @@ -85,7 +85,7 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.services.model_loader.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, @@ -93,7 +93,7 @@ def _lora_loader(): ).context.model, ) ) - except ModelNotFoundException: + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) @@ -159,11 +159,11 @@ def run_clip_compel( lora_prefix: str, zero_on_empty: bool, ): - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_loader.get_model( **clip_field.tokenizer.dict(), context=context, ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_loader.get_model( **clip_field.text_encoder.dict(), context=context, ) @@ -186,12 +186,12 @@ def run_clip_compel( def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context) + lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): @@ -200,7 +200,7 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.services.model_loader.get_model( model_name=name, 
base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, @@ -208,7 +208,7 @@ def _lora_loader(): ).context.model, ) ) - except ModelNotFoundException: + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 933c32c9080..6712fc4adec 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -28,7 +28,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput -from ...backend.model_management import BaseModelType +from ...backend.model_manager import BaseModelType from ..models.image import ImageCategory, ResourceOrigin from .baseinvocation import ( BaseInvocation, diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 3e3a3d9b1f4..e58f3f140ea 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -17,8 +17,8 @@ invocation_output, ) from invokeai.app.invocations.primitives import ImageField -from invokeai.backend.model_management.models.base import BaseModelType, ModelType -from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.model_manager.models.ip_adapter import get_ip_adapter_image_encoder_model_id class IPAdapterModelField(BaseModel): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c6bf37bdbcd..a49194d0d72 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -37,12 +37,11 @@ from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_management.models import ModelType, SilenceWarnings +from invokeai.backend.model_manager import BaseModelType, ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import BaseModelType -from ...backend.model_management.seamless import set_seamless +from ...backend.model_manager.lora import ModelPatcher +from ...backend.model_manager.seamless import set_seamless from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, @@ -133,7 +132,7 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: ) if image is not None: - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_loader.get_model( **self.vae.vae.dict(), context=context, ) @@ -166,7 +165,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( + orig_scheduler_info = context.services.model_loader.get_model( **scheduler_info.dict(), context=context, ) @@ -362,7 +361,7 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( + 
context.services.model_loader.get_model( model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, @@ -430,7 +429,7 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( + context.services.model_loader.get_model( model_name=single_ip_adapter.ip_adapter_model.model_name, model_type=ModelType.IPAdapter, base_model=single_ip_adapter.ip_adapter_model.base_model, @@ -438,7 +437,7 @@ def prep_ip_adapter_data( ) ) - image_encoder_model_info = context.services.model_manager.get_model( + image_encoder_model_info = context.services.model_loader.get_model( model_name=single_ip_adapter.image_encoder_model.model_name, model_type=ModelType.CLIPVision, base_model=single_ip_adapter.image_encoder_model.base_model, @@ -488,7 +487,7 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( + t2i_adapter_model_info = context.services.model_loader.get_model( model_name=t2i_adapter_field.t2i_adapter_model.model_name, model_type=ModelType.T2IAdapter, base_model=t2i_adapter_field.t2i_adapter_model.base_model, @@ -640,7 +639,7 @@ def step_callback(state: PipelineIntermediateState): def _lora_loader(): for lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_loader.get_model( **lora.dict(exclude={"weight"}), context=context, ) @@ -648,7 +647,7 @@ def _lora_loader(): del lora_info return - unet_info = context.services.model_manager.get_model( + unet_info = context.services.model_loader.get_model( **self.unet.unet.dict(), context=context, ) @@ -753,7 +752,7 @@ class LatentsToImageInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_loader.get_model( **self.vae.vae.dict(), context=context, ) @@ -978,7 +977,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.services.images.get_pil_image(self.image.image_name) - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_loader.get_model( **self.vae.vae.dict(), context=context, ) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 571cb2e7303..21d04604dbf 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -3,7 +3,8 @@ from pydantic import BaseModel, Field -from ...backend.model_management import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import SubModelType + from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -19,9 +20,7 @@ class ModelInfo(BaseModel): - model_name: str = Field(description="Info to load submodel") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Info to load submodel") + key: str = Field(description="Unique ID for model") submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") @@ -61,16 +60,13 @@ class ModelLoaderOutput(BaseInvocationOutput): class MainModelField(BaseModel): """Main model field""" - model_name: str = 
Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") + key: str = Field(description="Unique ID of the model") class LoRAModelField(BaseModel): """LoRA model field""" - model_name: str = Field(description="Name of the LoRA model") - base_model: BaseModelType = Field(description="Base model") + key: str = Field(description="Unique ID for model") @invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0") @@ -81,20 +77,15 @@ class MainModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + """Load a main model, outputting its submodels.""" + key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_record_store.model_exists(key): + raise Exception(f"Unknown model {key}") """ - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=self.model_name, model_type=SDModelType.Diffusers, submodel=SDModelType.Tokenizer, @@ -103,7 +94,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" ) - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=self.model_name, model_type=SDModelType.Diffusers, submodel=SDModelType.TextEncoder, @@ -112,7 +103,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: f"Failed to find text_encoder submodel in {self.model_name}! 
Check if model corrupted" ) - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=self.model_name, model_type=SDModelType.Diffusers, submodel=SDModelType.UNet, @@ -125,30 +116,22 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.TextEncoder, ), loras=[], @@ -156,9 +139,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Vae, ), ), @@ -167,7 +148,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: @invocation_output("lora_loader_output") class LoraLoaderOutput(BaseInvocationOutput): - """Model loader output""" + """Model loader output.""" unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") @@ -187,24 +168,20 @@ class LoraLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + """Load a LoRA model.""" if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unkown lora name: {lora_name}!") + if not context.services.model_record_store.model_exists(key): + raise Exception(f"Unknown lora: {key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == key for lora in self.unet.loras): + raise Exception(f'Lora "{key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == key for lora in self.clip.loras): + raise Exception(f'Lora "{key}" already applied to clip') output = LoraLoaderOutput() @@ -212,9 +189,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=key, submodel=None, weight=self.weight, ) @@ -224,9 +199,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=key, submodel=None, weight=self.weight, ) @@ -237,7 +210,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: 
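The loader invocations above now reference models by a single record `key` rather than the old `(base_model, model_type, model_name)` triplet. As a minimal, hypothetical sketch of how a legacy name-based reference could be resolved to a key, assuming `search_by_name()` returns config objects carrying `name` and `key` attributes as the router code earlier in this diff suggests:

```
# Hypothetical helper, for illustration only: resolve a legacy
# (base_model, model_type, model_name) reference to the unique key that
# ModelInfo / LoraInfo now expect. Assumes search_by_name() returns config
# objects exposing .name and .key.
def resolve_model_key(record_store, base_model, model_type, model_name: str) -> str:
    matches = [
        config
        for config in record_store.search_by_name(base_model=base_model, model_type=model_type)
        if config.name == model_name
    ]
    if not matches:
        raise Exception(f"Unknown model: {model_name}")
    if len(matches) > 1:
        raise Exception(f"Model name '{model_name}' is ambiguous; use its key instead")
    return matches[0].key
```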
@invocation_output("sdxl_lora_loader_output") class SDXLLoraLoaderOutput(BaseInvocationOutput): - """SDXL LoRA Loader Output""" + """SDXL LoRA Loader Output.""" unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1") @@ -261,27 +234,22 @@ class SDXLLoraLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + """Load an SDXL LoRA.""" if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + key = self.lora.key + if not context.services.model_record_store.model_exists(key): + raise Exception(f"Unknown lora name: {key}!") - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unknown lora name: {lora_name}!") - - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == key for lora in self.unet.loras): + raise Exception(f'Lora "{key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == key for lora in self.clip.loras): + raise Exception(f'Lora "{key}" already applied to clip') - if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip2') + if self.clip2 is not None and any(lora.key == key for lora in self.clip2.loras): + raise Exception(f'Lora "{key}" already applied to clip2') output = SDXLLoraLoaderOutput() @@ -289,9 +257,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=key, submodel=None, weight=self.weight, ) @@ -301,9 +267,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=key, submodel=None, weight=self.weight, ) @@ -313,9 +277,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip2 = copy.deepcopy(self.clip2) output.clip2.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=key, submodel=None, weight=self.weight, ) @@ -325,10 +287,9 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: class VAEModelField(BaseModel): - """Vae model field""" + """Vae model field.""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") + key: str = Field(description="Unique ID for VAE model") @invocation_output("vae_loader_output") @@ -340,29 +301,22 @@ class VaeLoaderOutput(BaseInvocationOutput): @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") class VaeLoaderInvocation(BaseInvocation): - """Loads a VAE model, outputting a VaeLoaderOutput""" + """Loads a VAE model, outputting a VaeLoaderOutput.""" vae_model: VAEModelField = 
InputField( description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE" ) def invoke(self, context: InvocationContext) -> VaeLoaderOutput: - base_model = self.vae_model.base_model - model_name = self.vae_model.model_name - model_type = ModelType.Vae - - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, - ): - raise Exception(f"Unkown vae name: {model_name}!") + """Load a VAE model.""" + key = self.vae_model.key + + if not context.services.model_record_store.model_exists(key): + raise Exception(f"Unknown vae name: {key}!") return VaeLoaderOutput( vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, ) ) ) @@ -370,7 +324,7 @@ def invoke(self, context: InvocationContext) -> VaeLoaderOutput: @invocation_output("seamless_output") class SeamlessModeOutput(BaseInvocationOutput): - """Modified Seamless Model output""" + """Modified Seamless Model output.""" unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet") vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE") @@ -390,6 +344,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") def invoke(self, context: InvocationContext) -> SeamlessModeOutput: + """Apply seamless transformation.""" # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 1d531d45a26..0b008e89bb9 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -17,7 +17,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend import BaseModelType, ModelType, SubModelType -from ...backend.model_management import ONNXModelPatcher +from ...backend.model_manager.lora import ONNXModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util import choose_torch_device from ..models.image import ImageCategory, ResourceOrigin @@ -62,15 +62,15 @@ class ONNXPromptInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_loader.get_model( **self.clip.tokenizer.dict(), ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_loader.get_model( **self.clip.text_encoder.dict(), ) with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: loras = [ - (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) + (context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras ] @@ -81,7 +81,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ti_list.append( ( name, - context.services.model_manager.get_model( + context.services.model_loader.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, @@ -254,12 +254,12 @@ def dispatch_progress( eta=0.0, ) - unet_info =
context.services.model_manager.get_model(**self.unet.unet.dict()) + unet_info = context.services.model_loader.get_model(**self.unet.unet.dict()) with unet_info as unet: # , ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] + # loras = [(stack.enter_context(context.services.model_loader.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [ - (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) + (context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras ] @@ -345,7 +345,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if self.vae.vae.submodel != SubModelType.VaeDecoder: raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_loader.get_model( **self.vae.vae.dict(), ) @@ -418,7 +418,7 @@ def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: model_type = ModelType.ONNX # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -426,7 +426,7 @@ def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") """ - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=self.model_name, model_type=SDModelType.Diffusers, submodel=SDModelType.Tokenizer, @@ -435,7 +435,7 @@ def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" ) - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=self.model_name, model_type=SDModelType.Diffusers, submodel=SDModelType.TextEncoder, @@ -444,7 +444,7 @@ def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: f"Failed to find text_encoder submodel in {self.model_name}! 
Check if model corrupted" ) - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=self.model_name, model_type=SDModelType.Diffusers, submodel=SDModelType.UNet, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index de4ea604b44..fadb6320649 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,4 +1,4 @@ -from ...backend.model_management import ModelType, SubModelType +from ...backend.model_manager import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -48,7 +48,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -137,7 +137,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.services.model_record_store.model_exists( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e1bd8d0d04a..19b3d204513 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -16,7 +16,7 @@ ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.primitives import ImageField -from invokeai.backend.model_management.models.base import BaseModelType +from invokeai.backend.model_manager import BaseModelType class T2IAdapterModelField(BaseModel): diff --git a/invokeai/app/services/config/base.py b/invokeai/app/services/config/base.py index f24879af052..d3b7b76fb50 100644 --- a/invokeai/app/services/config/base.py +++ b/invokeai/app/services/config/base.py @@ -25,6 +25,7 @@ class PagingArgumentParser(argparse.ArgumentParser): """ A custom ArgumentParser that uses pydoc to page its output. + It also supports reading defaults from an init file. """ @@ -144,16 +145,6 @@ def _excluded_from_yaml(cls) -> List[str]: return [ "type", "initconf", - "version", - "from_file", - "model", - "root", - "max_cache_size", - "max_vram_cache_size", - "always_use_cpu", - "free_gpu_mem", - "xformers_enabled", - "tiled_decode", ] class Config: @@ -226,9 +217,7 @@ def add_field_argument(cls, command_parser, name: str, field, default_override=N def int_or_float_or_str(value: str) -> Union[int, float, str]: - """ - Workaround for argparse type checking. 
- """ + """Workaround for argparse type checking.""" try: return int(value) except Exception as e: # noqa F841 diff --git a/invokeai/app/services/config/invokeai_config.py b/invokeai/app/services/config/invokeai_config.py index d8b598815d1..d17bca77910 100644 --- a/invokeai/app/services/config/invokeai_config.py +++ b/invokeai/app/services/config/invokeai_config.py @@ -171,6 +171,7 @@ class InvokeBatch(InvokeAISettings): from __future__ import annotations import os +import sys from pathlib import Path from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints @@ -182,7 +183,9 @@ class InvokeBatch(InvokeAISettings): INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") -DEFAULT_MAX_VRAM = 0.5 +DEFAULT_MAX_DISK_CACHE = 20 # gigs, enough for three sdxl models, or 6 sd-1 models +DEFAULT_RAM_CACHE = 7.5 +DEFAULT_VRAM_CACHE = 0.25 class InvokeAIAppConfig(InvokeAISettings): @@ -217,11 +220,8 @@ class InvokeAIAppConfig(InvokeAISettings): # PATHS root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths') - autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths') - lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths') - embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths') - controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths') - conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') + autoimport_dir : Optional[Path] = Field(default=None, description='Path to a directory of models files to be imported on startup.', category='Paths') + model_config_db : Union[Path, Literal['auto'], None] = Field(default=None, description='Path to a sqlite .db file or .yaml file for storing model config records; "auto" will reuse the main sqlite db', category='Paths') models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths') legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths') db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths') @@ -241,8 +241,9 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") # CACHE - ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", ) - vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", ) + ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching", category="Model Cache", ) + vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage", category="Model Cache", ) + disk : float = Field(default=DEFAULT_MAX_DISK_CACHE, ge=0, description="Maximum size (in GB) for the 
disk-based diffusers model conversion cache", category="Model Cache", ) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", ) # DEVICE @@ -254,7 +255,6 @@ class InvokeAIAppConfig(InvokeAISettings): attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", ) attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", ) force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) - force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", ) # QUEUE @@ -272,6 +272,10 @@ class InvokeAIAppConfig(InvokeAISettings): max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') + conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') + lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths') + embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths') + controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths') # See InvokeAIAppConfig subclass below for CACHE and DEVICE categories # fmt: on @@ -312,9 +316,7 @@ def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig @classmethod def get_config(cls, **kwargs) -> InvokeAIAppConfig: - """ - This returns a singleton InvokeAIAppConfig configuration object. 
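As a point of reference, here is a minimal sketch of how the new cache settings are read back through the singleton config; the values are the new defaults (7.5 GB RAM, 0.25 GB VRAM, 20 GB conversion cache) unless overridden in `invokeai.yaml`:

```
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()      # process-wide singleton
print(config.ram, config.vram, config.disk)  # raw field values (defaults or invokeai.yaml)
print(config.ram_cache_size)                 # prefers the legacy max_cache_size setting when present
print(config.vram_cache_size)                # prefers the legacy max_vram_cache_size setting when present
print(config.conversion_cache_size)          # same value as config.disk
```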
- """ + """This returns a singleton InvokeAIAppConfig configuration object.""" if ( cls.singleton_config is None or type(cls.singleton_config) is not cls @@ -324,6 +326,29 @@ def get_config(cls, **kwargs) -> InvokeAIAppConfig: cls.singleton_init = kwargs return cls.singleton_config + @classmethod + def _excluded_from_yaml(cls) -> List[str]: + el = super()._excluded_from_yaml() + el.extend( + [ + "version", + "from_file", + "model", + "root", + "max_cache_size", + "max_vram_cache_size", + "always_use_cpu", + "free_gpu_mem", + "xformers_enabled", + "tiled_decode", + "conf_path", + "lora_dir", + "embedding_dir", + "controlnet_dir", + ] + ) + return el + @property def root_path(self) -> Path: """ @@ -414,7 +439,11 @@ def ram_cache_size(self) -> Union[Literal["auto"], float]: return self.max_cache_size or self.ram @property - def vram_cache_size(self) -> Union[Literal["auto"], float]: + def conversion_cache_size(self) -> float: + return self.disk + + @property + def vram_cache_size(self) -> float: return self.max_vram_cache_size or self.vram @property @@ -440,9 +469,7 @@ def find_root() -> Path: def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: - """ - Legacy function which returns InvokeAIAppConfig.get_config() - """ + """Legacy function which returns InvokeAIAppConfig.get_config().""" return InvokeAIAppConfig.get_config(**kwargs) diff --git a/invokeai/app/services/download_manager.py b/invokeai/app/services/download_manager.py new file mode 100644 index 00000000000..fd0de798fff --- /dev/null +++ b/invokeai/app/services/download_manager.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Model download service. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, List, Optional, Union + +from pydantic.networks import AnyHttpUrl + +from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401 +from invokeai.backend.model_manager.download import ( # noqa F401 + DownloadEventHandler, + DownloadJobBase, + DownloadJobPath, + DownloadJobStatus, + DownloadQueueBase, + ModelDownloadQueue, + ModelSourceMetadata, + UnknownJobIDException, +) + +if TYPE_CHECKING: + from .events import EventServiceBase + + +class DownloadQueueServiceBase(ABC): + """Multithreaded queue for downloading models via URL or repo_id.""" + + @abstractmethod + def create_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + destdir: Path, + filename: Optional[Path] = None, + start: Optional[bool] = True, + access_token: Optional[str] = None, + event_handlers: Optional[List[DownloadEventHandler]] = None, + ) -> DownloadJobBase: + """ + Create a download job. + + :param source: Source of the download - URL, repo_id or local Path + :param destdir: Directory to download into. + :param filename: Optional name of file, if not provided + will use the content-disposition field to assign the name. + :param start: Immediately start job [True] + :param event_handler: Callable that receives a DownloadJobBase and acts on it. + :returns job id: The numeric ID of the DownloadJobBase object for this task. + """ + pass + + @abstractmethod + def submit_download_job( + self, + job: DownloadJobBase, + start: Optional[bool] = True, + ): + """ + Submit a download job. + + :param job: A DownloadJobBase + :param start: Immediately start job [True] + + After execution, `job.id` will be set to a non-negative value. 
+ """ + pass + + @abstractmethod + def list_jobs(self) -> List[DownloadJobBase]: + """ + List active DownloadJobBases. + + :returns List[DownloadJobBase]: List of download jobs whose state is not "completed." + """ + pass + + @abstractmethod + def id_to_job(self, id: int) -> DownloadJobBase: + """ + Return the DownloadJobBase corresponding to the string ID. + + :param id: ID of the DownloadJobBase. + + Exceptions: + * UnknownJobIDException + """ + pass + + @abstractmethod + def start_all_jobs(self): + """Enqueue all idle and paused jobs.""" + pass + + @abstractmethod + def pause_all_jobs(self): + """Pause and dequeue all active jobs.""" + pass + + @abstractmethod + def cancel_all_jobs(self): + """Cancel all active and enquedjobs.""" + pass + + @abstractmethod + def prune_jobs(self): + """Prune completed and errored queue items from the job list.""" + pass + + @abstractmethod + def start_job(self, job: DownloadJobBase): + """Start the job putting it into ENQUEUED state.""" + pass + + @abstractmethod + def pause_job(self, job: DownloadJobBase): + """Pause the job, putting it into PAUSED state.""" + pass + + @abstractmethod + def cancel_job(self, job: DownloadJobBase): + """Cancel the job, clearing partial downloads and putting it into ERROR state.""" + pass + + @abstractmethod + def join(self): + """Wait until all jobs are off the queue.""" + pass + + +class DownloadQueueService(DownloadQueueServiceBase): + """Multithreaded queue for downloading models via URL or repo_id.""" + + _event_bus: Optional["EventServiceBase"] = None + _queue: DownloadQueueBase + + def __init__(self, event_bus: Optional["EventServiceBase"] = None, **kwargs): + """ + Initialize new DownloadQueueService object. + + :param event_bus: EventServiceBase object for reporting progress. + :param **kwargs: Any of the arguments taken by invokeai.backend.model_manager.download.DownloadQueue. + e.g. `max_parallel_dl`. 
+ """ + self._event_bus = event_bus + self._queue = ModelDownloadQueue(**kwargs) + + def create_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + destdir: Path, + filename: Optional[Path] = None, + start: Optional[bool] = True, + access_token: Optional[str] = None, + event_handlers: Optional[List[DownloadEventHandler]] = None, + ) -> DownloadJobBase: # noqa D102 + event_handlers = event_handlers or [] + if self._event_bus: + event_handlers = [*event_handlers, self._event_bus.emit_model_event] + return self._queue.create_download_job( + source=source, + destdir=destdir, + filename=filename, + start=start, + access_token=access_token, + event_handlers=event_handlers, + ) + + def submit_download_job( + self, + job: DownloadJobBase, + start: bool = True, + ): + return self._queue.submit_download_job(job, start) + + def list_jobs(self) -> List[DownloadJobBase]: # noqa D102 + return self._queue.list_jobs() + + def id_to_job(self, id: int) -> DownloadJobBase: # noqa D102 + return self._queue.id_to_job(id) + + def start_all_jobs(self): # noqa D102 + return self._queue.start_all_jobs() + + def pause_all_jobs(self): # noqa D102 + return self._queue.pause_all_jobs() + + def cancel_all_jobs(self): # noqa D102 + return self._queue.cancel_all_jobs() + + def prune_jobs(self): # noqa D102 + return self._queue.prune_jobs() + + def start_job(self, job: DownloadJobBase): # noqa D102 + return self._queue.start_job(job) + + def pause_job(self, job: DownloadJobBase): # noqa D102 + return self._queue.pause_job(job) + + def cancel_job(self, job: DownloadJobBase): # noqa D102 + return self._queue.cancel_job(job) + + def join(self): # noqa D102 + return self._queue.join() diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 0a02a03539d..f52323c46a3 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,7 +3,7 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage -from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType +from invokeai.app.services.model_record_service import BaseModelType, ModelType, SubModelType from invokeai.app.services.session_queue.session_queue_common import ( BatchStatus, EnqueueBatchResult, @@ -11,14 +11,17 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.download import DownloadJobBase +from invokeai.backend.model_manager.loader import ModelInfo +from invokeai.backend.util.logging import InvokeAILogger class EventServiceBase: queue_event: str = "queue_event" - """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: + """Dispatch an event.""" pass def __emit_queue_event(self, event_name: str, payload: dict) -> None: @@ -153,9 +156,7 @@ def emit_model_load_started( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + model_key: str, submodel: SubModelType, ) -> None: """Emitted when a model is requested""" @@ -166,9 +167,7 @@ def emit_model_load_started( queue_item_id=queue_item_id, queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, + model_key=model_key, submodel=submodel, ), ) @@ -179,9 +178,7 @@ def emit_model_load_completed( queue_item_id: int, 
queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + model_key: str, submodel: SubModelType, model_info: ModelInfo, ) -> None: @@ -193,9 +190,7 @@ def emit_model_load_completed( queue_item_id=queue_item_id, queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, + model_key=model_key, submodel=submodel, hash=model_info.hash, location=str(model_info.location), @@ -312,3 +307,9 @@ def emit_queue_cleared(self, queue_id: str) -> None: event_name="queue_cleared", payload=dict(queue_id=queue_id), ) + + def emit_model_event(self, job: DownloadJobBase) -> None: + """Emit event when the status of a download/install job changes.""" + self.dispatch( # use dispatch() directly here because we are not a session event. + event_name="model_event", payload=dict(job=job) + ) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index e496ff80f27..431d6aea05b 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -9,6 +9,7 @@ from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.config import InvokeAIAppConfig + from invokeai.app.services.download_manager import DownloadQueueServiceBase from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.images import ImageServiceABC @@ -18,7 +19,9 @@ from invokeai.app.services.invoker import InvocationProcessorABC from invokeai.app.services.item_storage import ItemStorageABC from invokeai.app.services.latent_storage import LatentsStorageBase - from invokeai.app.services.model_manager_service import ModelManagerServiceBase + from invokeai.app.services.model_install_service import ModelInstallServiceBase + from invokeai.app.services.model_loader_service import ModelLoadServiceBase + from invokeai.app.services.model_record_service import ModelRecordServiceBase from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase @@ -35,8 +38,11 @@ class InvocationServices: graph_library: "ItemStorageABC[LibraryGraph]" images: "ImageServiceABC" latents: "LatentsStorageBase" + download_queue: "DownloadQueueServiceBase" + model_record_store: "ModelRecordServiceBase" + model_loader: "ModelLoadServiceBase" + model_installer: "ModelInstallServiceBase" logger: "Logger" - model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" performance_statistics: "InvocationStatsServiceBase" queue: "InvocationQueueABC" @@ -55,7 +61,10 @@ def __init__( images: "ImageServiceABC", latents: "LatentsStorageBase", logger: "Logger", - model_manager: "ModelManagerServiceBase", + download_queue: "DownloadQueueServiceBase", + model_record_store: "ModelRecordServiceBase", + model_loader: "ModelLoadServiceBase", + model_installer: "ModelInstallServiceBase", processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", @@ -72,7 +81,10 @@ def __init__( self.images = images self.latents = latents self.logger = logger - self.model_manager = model_manager + self.download_queue = download_queue + self.model_record_store = model_record_store + 
self.model_loader = model_loader + self.model_installer = model_installer self.processor = processor self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 33932f73aad..6dff5449b35 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -38,12 +38,12 @@ import torch import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.backend.model_manager.cache import CacheStats from ..invocations.baseinvocation import BaseInvocation from .graph import GraphExecutionState from .item_storage import ItemStorageABC -from .model_manager_service import ModelManagerService +from .model_loader_service import ModelLoadServiceBase # size of GIG in bytes GIG = 1073741824 @@ -174,13 +174,13 @@ class StatsContext: graph_id: str start_time: float ram_used: int - model_manager: ModelManagerService + model_loader: ModelLoadServiceBase def __init__( self, invocation: BaseInvocation, graph_id: str, - model_manager: ModelManagerService, + model_loader: ModelLoadServiceBase, collector: "InvocationStatsServiceBase", ): """Initialize statistics for this run.""" @@ -189,15 +189,15 @@ def __init__( self.graph_id = graph_id self.start_time = 0.0 self.ram_used = 0 - self.model_manager = model_manager + self.model_loader = model_loader def __enter__(self): self.start_time = time.time() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() self.ram_used = psutil.Process().memory_info().rss - if self.model_manager: - self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id]) + if self.model_loader: + self.model_loader.collect_cache_stats(self.collector._cache_stats[self.graph_id]) def __exit__(self, *args): """Called on exit from the context.""" @@ -208,7 +208,7 @@ def __exit__(self, *args): ) self.collector.update_invocation_stats( graph_id=self.graph_id, - invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations + invocation_type=self.invocation.type, time_used=time.time() - self.start_time, vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0, ) @@ -217,12 +217,12 @@ def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, - model_manager: ModelManagerService, + model_loader: ModelLoadServiceBase, ) -> StatsContext: if not self._stats.get(graph_execution_state_id): # first time we're seeing this self._stats[graph_execution_state_id] = NodeLog() self._cache_stats[graph_execution_state_id] = CacheStats() - return self.StatsContext(invocation, graph_execution_state_id, model_manager, self) + return self.StatsContext(invocation, graph_execution_state_id, model_loader, self) def reset_all_stats(self): """Zero all statistics""" diff --git a/invokeai/app/services/model_convert.py b/invokeai/app/services/model_convert.py new file mode 100644 index 00000000000..d17a5986c80 --- /dev/null +++ b/invokeai/app/services/model_convert.py @@ -0,0 +1,192 @@ +# Copyright 2023 Lincoln Stein and the InvokeAI Team + +""" +Convert and merge models. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from shutil import move, rmtree +from typing import List, Optional + +from pydantic import Field + +from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger + +from .config import InvokeAIAppConfig +from .model_install_service import ModelInstallServiceBase +from .model_loader_service import ModelInfo, ModelLoadServiceBase +from .model_record_service import ModelConfigBase, ModelRecordServiceBase, ModelType, SubModelType + + +class ModelConvertBase(ABC): + """Convert and merge models.""" + + @abstractmethod + def __init__( + cls, + loader: ModelLoadServiceBase, + installer: ModelInstallServiceBase, + store: ModelRecordServiceBase, + ): + """Initialize ModelConvert with loader, installer and configuration store.""" + pass + + @abstractmethod + def convert_model( + self, + key: str, + dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: + """ + Convert a checkpoint file into a diffusers folder. + + It will delete the cached version ans well as the + original checkpoint file if it is in the models directory. + :param key: Unique key of model. + :dest_directory: Optional place to put converted file. If not specified, + will be stored in the `models_dir`. + + This will raise a ValueError unless the model is a checkpoint. + This will raise an UnknownModelException if key is unknown. + """ + pass + + def merge_models( + self, + model_keys: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model keys to merge" + ), + merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: + """ + Merge two to three diffusrs pipeline models and save as a new model. + + :param model_keys: List of 2-3 model unique keys to merge + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d model + :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) + """ + pass + + +class ModelConvert(ModelConvertBase): + """Implementation of ModelConvertBase.""" + + def __init__( + self, + loader: ModelLoadServiceBase, + installer: ModelInstallServiceBase, + store: ModelRecordServiceBase, + ): + """Initialize ModelConvert with loader, installer and configuration store.""" + self.loader = loader + self.installer = installer + self.store = store + + def convert_model( + self, + key: str, + dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: + """ + Convert a checkpoint file into a diffusers folder. + + It will delete the cached version as well as the + original checkpoint file if it is in the models directory. + :param key: Unique key of model. + :dest_directory: Optional place to put converted file. If not specified, + will be stored in the `models_dir`. + + This will raise a ValueError unless the model is a checkpoint. + This will raise an UnknownModelException if key is unknown. 
+ """ + new_diffusers_path = None + config = InvokeAIAppConfig.get_config() + + try: + info: ModelConfigBase = self.store.get_model(key) + + if info.model_format != "checkpoint": + raise ValueError(f"not a checkpoint format model: {info.name}") + + # We are taking advantage of a side effect of get_model() that converts check points + # into cached diffusers directories stored at `path`. It doesn't matter + # what submodel type we request here, so we get the smallest. + submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {} + converted_model: ModelInfo = self.loader.get_model(key, **submodel) + + checkpoint_path = config.models_path / info.path + old_diffusers_path = config.models_path / converted_model.location + + # new values to write in + update = info.dict() + update.pop("config") + update["model_format"] = "diffusers" + update["path"] = str(converted_model.location) + + if dest_directory: + new_diffusers_path = Path(dest_directory) / info.name + if new_diffusers_path.exists(): + raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") + move(old_diffusers_path, new_diffusers_path) + update["path"] = new_diffusers_path.as_posix() + + self.store.update_model(key, update) + result = self.installer.sync_model_path(key, ignore_hash_change=True) + except Exception as excp: + # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! + if new_diffusers_path: + rmtree(new_diffusers_path) + raise excp + + if checkpoint_path.exists() and checkpoint_path.is_relative_to(config.models_path): + checkpoint_path.unlink() + + return result + + def merge_models( + self, + model_keys: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model keys to merge" + ), + merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: + """ + Merge two to three diffusrs pipeline models and save as a new model. + + :param model_keys: List of 2-3 model unique keys to merge + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d model + :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) + """ + pass + merger = ModelMerger(self.store) + try: + if not merged_model_name: + merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys]) + raise Exception("not implemented") + + result = merger.merge_diffusion_models_and_save( + model_keys=model_keys, + merged_model_name=merged_model_name, + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory=merge_dest_directory, + ) + except AssertionError as e: + raise ValueError(e) + return result diff --git a/invokeai/app/services/model_install_service.py b/invokeai/app/services/model_install_service.py new file mode 100644 index 00000000000..cfb1887d4b8 --- /dev/null +++ b/invokeai/app/services/model_install_service.py @@ -0,0 +1,653 @@ +# Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Development Team + +import re +import tempfile +from abc import ABC, abstractmethod +from pathlib import Path +from shutil import move, rmtree +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union + +from pydantic import Field +from pydantic.networks import AnyHttpUrl + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.backend import get_precision +from invokeai.backend.model_manager.config import ( + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SubModelType, +) +from invokeai.backend.model_manager.download.model_queue import ( + HTTP_RE, + REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, + DownloadJobMetadataURL, + DownloadJobRepoID, + DownloadJobWithMetadata, +) +from invokeai.backend.model_manager.hash import FastModelHash +from invokeai.backend.model_manager.models import InvalidModelException +from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo +from invokeai.backend.model_manager.search import ModelSearch +from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore +from invokeai.backend.util import Chdir, InvokeAILogger, Logger + +if TYPE_CHECKING: + from .events import EventServiceBase + +from .download_manager import ( + DownloadEventHandler, + DownloadJobBase, + DownloadJobPath, + DownloadQueueService, + DownloadQueueServiceBase, + ModelSourceMetadata, +) + + +class ModelInstallJob(DownloadJobBase): + """This is a version of DownloadJobBase that has an additional slot for the model key and probe info.""" + + model_key: Optional[str] = Field( + description="After model installation, this field will hold its primary key", default=None + ) + probe_override: Optional[Dict[str, Any]] = Field( + description="Keys in this dict will override like-named attributes in the automatic probe info", + default=None, + ) + + +class ModelInstallURLJob(DownloadJobMetadataURL, ModelInstallJob): + """Job for installing URLs.""" + + +class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob): + """Job for installing repo ids.""" + + +class ModelInstallPathJob(DownloadJobPath, ModelInstallJob): + """Job for installing local paths.""" + + +ModelInstallEventHandler = Callable[["ModelInstallJob"], None] + + +class ModelInstallServiceBase(ABC): + """Abstract base class for InvokeAI model installation.""" + + @abstractmethod + def __init__( + self, + config: Optional[InvokeAIAppConfig] = None, + queue: Optional[DownloadQueueServiceBase] = None, + store: Optional[ModelRecordServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, + event_handlers: List[DownloadEventHandler] = [], + ): + """ + Create ModelInstallService object. + + :param config: Optional InvokeAIAppConfig. If None passed, + uses the system-wide default app config. + :param download: Optional DownloadQueueServiceBase object. If None passed, + a default queue object will be created. + :param store: Optional ModelConfigStore. If None passed, + defaults to `configs/models.yaml`. + :param event_bus: InvokeAI event bus for reporting events to. + :param event_handlers: List of event handlers to pass to the queue object. 
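A minimal sketch of driving the installer with its defaults; the repo_id below is only an example source:

```
from invokeai.app.services.model_install_service import ModelInstallService

installer = ModelInstallService()                           # default config, record store and download queue
installer.install_model("runwayml/stable-diffusion-v1-5")   # example HuggingFace repo_id
key_map = installer.wait_for_installs()                     # {source: key of installed model, or None on failure}
print(key_map)
```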
+ """ + pass + + @property + @abstractmethod + def queue(self) -> DownloadQueueServiceBase: + """Return the download queue used by the installer.""" + pass + + @property + @abstractmethod + def store(self) -> ModelRecordServiceBase: + """Return the storage backend used by the installer.""" + pass + + @property + @abstractmethod + def config(self) -> InvokeAIAppConfig: + """Return the app_config used by the installer.""" + pass + + @abstractmethod + def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str: + """ + Probe and register the model at model_path. + + :param model_path: Filesystem Path to the model. + :param overrides: Dict of attributes that will override probed values. + :returns id: The string ID of the registered model. + """ + pass + + @abstractmethod + def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str: + """ + Probe, register and install the model in the models directory. + + This involves moving the model from its current location into + the models directory handled by InvokeAI. + + :param model_path: Filesystem Path to the model. + :param overrides: Dictionary of model probe info fields that, if present, override probed values. + :returns id: The string ID of the installed model. + """ + pass + + @abstractmethod + def install_model( + self, + source: Union[str, Path, AnyHttpUrl], + inplace: bool = True, + priority: int = 10, + start: Optional[bool] = True, + variant: Optional[str] = None, + subfolder: Optional[str] = None, + probe_override: Optional[Dict[str, Any]] = None, + metadata: Optional[ModelSourceMetadata] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + """ + Download and install the indicated model. + + This will download the model located at `source`, + probe it, and install it into the models directory. + This call is executed asynchronously in a separate + thread, and the returned object is a + invokeai.backend.model_manager.download.DownloadJobBase + object which can be interrogated to get the status of + the download and install process. Call our `wait_for_installs()` + method to wait for all downloads and installations to complete. + + :param source: Either a URL or a HuggingFace repo_id. + :param inplace: If True, local paths will not be moved into + the models directory, but registered in place (the default). + :param variant: For HuggingFace models, this optional parameter + specifies which variant to download (e.g. 'fp16') + :param subfolder: When downloading HF repo_ids this can be used to + specify a subfolder of the HF repository to download from. + :param probe_override: Optional dict. Any fields in this dict + will override corresponding probe fields. Use it to override + `base_type`, `model_type`, `format`, `prediction_type` and `image_size`. + :param metadata: Use this to override the fields 'description`, + `author`, `tags`, `source` and `license`. + + :returns ModelInstallJob object. + + The `inplace` flag does not affect the behavior of downloaded + models, which are always moved into the `models` directory. + + Variants recognized by HuggingFace currently are: + 1. onnx + 2. openvino + 3. fp16 + 4. None (usually returns fp32 model) + """ + pass + + @abstractmethod + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: + """ + Wait for all pending installs to complete. + + This will block until all pending downloads have + completed, been cancelled, or errored out. 
It will + block indefinitely if one or more jobs are in the + paused state. + + It will return a dict that maps the source model + path, URL or repo_id to the ID of the installed model. + """ + pass + + @abstractmethod + def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: + """ + Recursively scan directory for new models and register or install them. + + :param scan_dir: Path to the directory to scan. + :param install: Install if True, otherwise register in place. + :returns list of IDs: Returns list of IDs of models registered/installed + """ + pass + + @abstractmethod + def sync_to_config(self): + """Synchronize models on disk to those in memory.""" + pass + + @abstractmethod + def hash(self, model_path: Union[Path, str]) -> str: + """ + Compute and return the fast hash of the model. + + :param model_path: Path to the model on disk. + :return str: FastHash of the model for use as an ID. + """ + pass + + +class ModelInstallService(ModelInstallServiceBase): + """Model installer class handles installation from a local path.""" + + _app_config: InvokeAIAppConfig + _logger: Logger + _store: ModelConfigStore + _download_queue: DownloadQueueServiceBase + _async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]] + _installed: Set[str] = Field(default=set) + _tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads + _cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning + _precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form") + _event_bus: Optional["EventServiceBase"] = Field(description="an event bus to send install events to", default=None) + + _legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = { + BaseModelType.StableDiffusion1: { + ModelVariantType.Normal: "v1-inference.yaml", + ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", + }, + BaseModelType.StableDiffusion2: { + ModelVariantType.Normal: { + SchedulerPredictionType.Epsilon: "v2-inference.yaml", + SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", + }, + ModelVariantType.Inpaint: { + SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", + SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", + }, + }, + BaseModelType.StableDiffusionXL: { + ModelVariantType.Normal: "sd_xl_base.yaml", + }, + BaseModelType.StableDiffusionXLRefiner: { + ModelVariantType.Normal: "sd_xl_refiner.yaml", + }, + } + + def __init__( + self, + config: Optional[InvokeAIAppConfig] = None, + queue: Optional[DownloadQueueServiceBase] = None, + store: Optional[ModelRecordServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, + event_handlers: List[DownloadEventHandler] = [], + ): # noqa D107 - use base class docstrings + self._app_config = config or InvokeAIAppConfig.get_config() + self._store = store or ModelRecordServiceBase.open(self._app_config) + self._logger = InvokeAILogger.get_logger(config=self._app_config) + self._event_bus = event_bus + self._precision = get_precision() + self._handlers = event_handlers + if self._event_bus: + self._handlers.append(self._event_bus.emit_model_event) + + self._download_queue = queue or DownloadQueueService(event_bus=event_bus) + self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict() + self._installed = set() + self._tmpdir = None + + def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any' + """Call automatically at process 
start.""" + self.sync_to_config() + + @property + def queue(self) -> DownloadQueueServiceBase: + """Return the queue.""" + return self._download_queue + + @property + def store(self) -> ModelConfigStore: + """Return the storage backend used by the installer.""" + return self._store + + @property + def config(self) -> InvokeAIAppConfig: + """Return the app_config used by the installer.""" + return self._app_config + + def install_model( + self, + source: Union[str, Path, AnyHttpUrl], + inplace: bool = True, + priority: int = 10, + start: Optional[bool] = True, + variant: Optional[str] = None, + subfolder: Optional[str] = None, + probe_override: Optional[Dict[str, Any]] = None, + metadata: Optional[ModelSourceMetadata] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: # noqa D102 + queue = self._download_queue + variant = variant or ("fp16" if self._precision == "float16" else None) + + job = self._make_download_job( + source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority + ) + handler = ( + self._complete_registration_handler + if inplace and Path(source).exists() + else self._complete_installation_handler + ) + if isinstance(job, ModelInstallJob): + job.probe_override = probe_override + if metadata and isinstance(job, DownloadJobWithMetadata): + job.metadata = metadata + job.add_event_handler(handler) + + self._async_installs[source] = None + queue.submit_download_job(job, start=start) + return job + + def register_path( + self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None + ) -> str: # noqa D102 + model_path = Path(model_path) + info: ModelProbeInfo = self._probe_model(model_path, overrides) + return self._register(model_path, info) + + def install_path( + self, + model_path: Union[Path, str], + overrides: Optional[Dict[str, Any]] = None, + ) -> str: # noqa D102 + model_path = Path(model_path) + info: ModelProbeInfo = self._probe_model(model_path, overrides) + + dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name + new_path = self._move_model(model_path, dest_path) + new_hash = self.hash(new_path) + assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted." + return self._register( + new_path, + info, + ) + + def unregister(self, key: str): # noqa D102 + self._store.del_model(key) + + def delete(self, key: str): # noqa D102 + model = self._store.get_model(key) + path = self._app_config.models_path / model.path + if path.is_dir(): + rmtree(path) + else: + path.unlink() + self.unregister(key) + + def conditionally_delete(self, key: str): # noqa D102 + """Unregister the model. 
Delete its files only if they are within our models directory.""" + model = self._store.get_model(key) + models_dir = self._app_config.models_path + model_path = models_dir / model.path + if model_path.is_relative_to(models_dir): + self.delete(key) + else: + self.unregister(key) + + def _register(self, model_path: Path, info: ModelProbeInfo) -> str: + key: str = FastModelHash.hash(model_path) + + model_path = model_path.absolute() + if model_path.is_relative_to(self._app_config.models_path): + model_path = model_path.relative_to(self._app_config.models_path) + + registration_data = dict( + path=model_path.as_posix(), + name=model_path.name if model_path.is_dir() else model_path.stem, + base_model=info.base_type, + model_type=info.model_type, + model_format=info.format, + hash=key, + ) + # add 'main' specific fields + if info.model_type == ModelType.Main: + if info.variant_type: + registration_data.update(variant=info.variant_type) + if info.format == ModelFormat.Checkpoint: + try: + config_file = self._legacy_configs[info.base_type][info.variant_type] + if isinstance(config_file, dict): # need another tier for sd-2.x models + if prediction_type := info.prediction_type: + config_file = config_file[prediction_type] + else: + self._logger.warning( + f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model" + ) + config_file = config_file[SchedulerPredictionType.VPrediction] + registration_data.update( + config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(), + ) + except KeyError as exc: + raise InvalidModelException( + "Configuration file for this checkpoint could not be determined" + ) from exc + self._store.add_model(key, registration_data) + return key + + def _move_model(self, old_path: Path, new_path: Path) -> Path: + if old_path == new_path: + return old_path + + new_path.parent.mkdir(parents=True, exist_ok=True) + + # if path already exists then we jigger the name to make it unique + counter: int = 1 + while new_path.exists(): + path = new_path.with_stem(new_path.stem + f"_{counter:02d}") + if not path.exists(): + new_path = path + counter += 1 + return move(old_path, new_path) + + def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo: + info: ModelProbeInfo = ModelProbe.probe(Path(model_path)) + if overrides: # used to override probe fields + for key, value in overrides.items(): + try: + setattr(info, key, value) # skip validation errors + except Exception: + pass + return info + + def _complete_installation_handler(self, job: DownloadJobBase): + assert isinstance(job, ModelInstallJob) + if job.status == "completed": + self._logger.info(f"{job.source}: Download finished with status {job.status}. 
Installing.") + model_id = self.install_path(job.destination, job.probe_override) + info = self._store.get_model(model_id) + info.source = str(job.source) + if isinstance(job, DownloadJobWithMetadata): + metadata: ModelSourceMetadata = job.metadata + info.description = metadata.description or f"Imported model {info.name}" + info.name = metadata.name or info.name + info.author = metadata.author + info.tags = metadata.tags + info.license = metadata.license + info.thumbnail_url = metadata.thumbnail_url + self._store.update_model(model_id, info) + self._async_installs[job.source] = model_id + job.model_key = model_id + elif job.status == "error": + self._logger.warning(f"{job.source}: Model installation error: {job.error}") + elif job.status == "cancelled": + self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.") + jobs = self._download_queue.list_jobs() + if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]: + self._tmpdir.cleanup() + self._tmpdir = None + + def _complete_registration_handler(self, job: DownloadJobBase): + assert isinstance(job, ModelInstallJob) + if job.status == "completed": + self._logger.info(f"{job.source}: Installing in place.") + model_id = self.register_path(job.destination, job.probe_override) + info = self._store.get_model(model_id) + info.source = str(job.source) + info.description = f"Imported model {info.name}" + self._store.update_model(model_id, info) + self._async_installs[job.source] = model_id + job.model_key = model_id + elif job.status == "error": + self._logger.warning(f"{job.source}: Model installation error: {job.error}") + elif job.status == "cancelled": + self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.") + + def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase: + """ + Move model into the location indicated by its basetype, type and name. + + Call this after updating a model's attributes in order to move + the model's path into the location indicated by its basetype, type and + name. Applies only to models whose paths are within the root `models_dir` + directory. + + May raise an UnknownModelException. + """ + model = self._store.get_model(key) + old_path = Path(model.path) + models_dir = self._app_config.models_path + + if not old_path.is_relative_to(models_dir): + return model + + new_path = models_dir / model.base_model.value / model.model_type.value / model.name + self._logger.info(f"Moving {model.name} to {new_path}.") + new_path = self._move_model(old_path, new_path) + model.hash = self.hash(new_path) + model.path = new_path.relative_to(models_dir).as_posix() + if model.hash != key: + assert ( + ignore_hash_change + ), f"{model.name}: Model hash changed during installation, model is possibly corrupted" + self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}") + self._store.update_model(key, model) + return model + + def _make_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + variant: Optional[str] = None, + subfolder: Optional[str] = None, + access_token: Optional[str] = None, + priority: Optional[int] = 10, + ) -> ModelInstallJob: + # Clean up a common source of error. Doesn't work with Paths. + if isinstance(source, str): + source = source.strip() + + # In the event that we are being asked to install a path that is already on disk, + # we simply probe and register/install it. 
The job does not actually do anything, but we + # create one anyway in order to have similar behavior for local files, URLs and repo_ids. + if Path(source).exists(): # a path that is already on disk + destdir = source + return ModelInstallPathJob(source=source, destination=Path(destdir), event_handlers=self._handlers) + + # choose a temporary directory inside the models directory + models_dir = self._app_config.models_path + self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir) + + cls = ModelInstallJob + if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)): + cls = ModelInstallRepoIDJob + source = match.group(1) + subfolder = match.group(2) or subfolder + kwargs = dict(variant=variant, subfolder=subfolder) + elif re.match(HTTP_RE, str(source)): + cls = ModelInstallURLJob + kwargs = {} + else: + raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL") + return cls( + source=str(source), + destination=Path(self._tmpdir.name), + access_token=access_token, + priority=priority, + event_handlers=self._handlers, + **kwargs, + ) + + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: + """Pause until all installation jobs have completed.""" + self._download_queue.join() + id_map = self._async_installs + self._async_installs = dict() + return id_map + + def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 + self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()]) + callback = self._scan_install if install else self._scan_register + search = ModelSearch(on_model_found=callback) + self._installed = set() + search.search(scan_dir) + return list(self._installed) + + def scan_models_directory(self): + """ + Scan the models directory for new and missing models. + + New models will be added to the storage backend. Missing models + will be deleted. + """ + defunct_models = set() + installed = set() + + with Chdir(self._app_config.models_path): + self._logger.info("Checking for models that have been moved or deleted from disk") + for model_config in self._store.all_models(): + path = Path(model_config.path) + if not path.exists(): + self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. 
Unregistering") + defunct_models.add(model_config.key) + for key in defunct_models: + self.unregister(key) + + self._logger.info(f"Scanning {self._app_config.models_path} for new models") + for cur_base_model in BaseModelType: + for cur_model_type in ModelType: + models_dir = Path(cur_base_model.value, cur_model_type.value) + installed.update(self.scan_directory(models_dir)) + self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") + + def sync_to_config(self): + """Synchronize models on disk to those in memory.""" + self.scan_models_directory() + if autoimport := self._app_config.autoimport_dir: + self._logger.info("Scanning autoimport directory for new models") + self.scan_directory(self._app_config.root_path / autoimport) + + def hash(self, model_path: Union[Path, str]) -> str: # noqa D102 + return FastModelHash.hash(model_path) + + def _scan_register(self, model: Path) -> bool: + if model in self._cached_model_paths: + return True + try: + id = self.register_path(model) + self.sync_model_path(id) # possibly move it to right place in `models` + self._logger.info(f"Registered {model.name} with id {id}") + self._installed.add(id) + except DuplicateModelException: + pass + return True + + def _scan_install(self, model: Path) -> bool: + if model in self._cached_model_paths: + return True + try: + id = self.install_path(model) + self._logger.info(f"Installed {model} with id {id}") + self._installed.add(id) + except DuplicateModelException: + pass + return True diff --git a/invokeai/app/services/model_loader_service.py b/invokeai/app/services/model_loader_service.py new file mode 100644 index 00000000000..b06d97e8114 --- /dev/null +++ b/invokeai/app/services/model_loader_service.py @@ -0,0 +1,140 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from pydantic import Field + +from invokeai.app.models.exceptions import CanceledException +from invokeai.backend.model_manager import ModelConfigStore, SubModelType +from invokeai.backend.model_manager.cache import CacheStats +from invokeai.backend.model_manager.loader import ModelConfigBase, ModelInfo, ModelLoad + +from .config import InvokeAIAppConfig +from .model_record_service import ModelRecordServiceBase + +if TYPE_CHECKING: + from ..invocations.baseinvocation import InvocationContext + + +class ModelLoadServiceBase(ABC): + """Load models into memory.""" + + @abstractmethod + def __init__( + self, + config: InvokeAIAppConfig, + store: Union[ModelConfigStore, ModelRecordServiceBase], + ): + """ + Initialize a ModelLoadService + + :param config: InvokeAIAppConfig object + :param store: ModelConfigStore object for fetching configuration information + installation and download events will be sent to the event bus. + """ + pass + + @abstractmethod + def get_model( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> ModelInfo: + """Retrieve the indicated model identified by key. + + :param key: Unique key returned by the ModelConfigStore module. + :param submodel_type: Submodel to return (required for main models) + :param context" Optional InvocationContext, used in event reporting. 
+ """ + pass + + @abstractmethod + def collect_cache_stats(self, cache_stats: CacheStats): + """Reset model cache statistics for graph with graph_id.""" + pass + + +# implementation +class ModelLoadService(ModelLoadServiceBase): + """Responsible for managing models on disk and in memory.""" + + _loader: ModelLoad + + def __init__( + self, + config: InvokeAIAppConfig, + record_store: Union[ModelConfigStore, ModelRecordServiceBase], + ): + """ + Initialize a ModelLoadService. + + :param config: InvokeAIAppConfig object + :param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information + installation and download events will be sent to the event bus. + """ + self._loader = ModelLoad(config, record_store) + + def get_model( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> ModelInfo: + """ + Retrieve the indicated model. + + The submodel is required when fetching a main model. + """ + model_info: ModelInfo = self._loader.get_model(key, submodel_type) + + # we can emit model loading events if we are executing with access to the invocation context + if context: + self._emit_load_event( + context=context, + model_key=key, + submodel=submodel_type, + model_info=model_info, + ) + + return model_info + + def collect_cache_stats(self, cache_stats: CacheStats): + """ + Reset model cache statistics. Is this used? + """ + self._loader.collect_cache_stats(cache_stats) + + def _emit_load_event( + self, + context: InvocationContext, + model_key: str, + submodel: Optional[SubModelType] = None, + model_info: Optional[ModelInfo] = None, + ): + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() + + if model_info: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_key=model_key, + submodel=submodel, + model_info=model_info, + ) + else: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_key=model_key, + submodel=submodel, + ) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py deleted file mode 100644 index 143fa8f3571..00000000000 --- a/invokeai/app/services/model_manager_service.py +++ /dev/null @@ -1,675 +0,0 @@ -# Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team - -from __future__ import annotations - -from abc import ABC, abstractmethod -from logging import Logger -from pathlib import Path -from types import ModuleType -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union - -import torch -from pydantic import Field - -from invokeai.app.models.exceptions import CanceledException -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - MergeInterpolationMethod, - ModelInfo, - ModelManager, - ModelMerger, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats -from invokeai.backend.model_management.model_search import FindModels - -from ...backend.util import choose_precision, choose_torch_device -from .config import InvokeAIAppConfig - -if TYPE_CHECKING: - from ..invocations.baseinvocation import BaseInvocation, InvocationContext - - -class ModelManagerServiceBase(ABC): - """Responsible for managing models on disk and in memory""" - - @abstractmethod - def __init__( - self, - config: InvokeAIAppConfig, - logger: ModuleType, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - pass - - @abstractmethod - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, - ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) - of a diffusers pipeline.""" - pass - - @property - @abstractmethod - def logger(self): - pass - - @abstractmethod - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - pass - - @abstractmethod - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - Uses the exact format as the omegaconf stanza. - """ - pass - - @abstractmethod - def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict: - """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } - """ - pass - - @abstractmethod - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Return information about the model using the same format as list_models() - """ - pass - - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - pass - - @abstractmethod - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. 
- On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException if the name does not already exist. - - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. - """ - pass - - @abstractmethod - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): - """ - Rename the indicated model. - """ - pass - - @abstractmethod - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - pass - - @abstractmethod - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - pass - - @abstractmethod - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. 
- """ - pass - - @abstractmethod - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_items=2, max_items=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - pass - - @abstractmethod - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - pass - - @abstractmethod - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - pass - - @abstractmethod - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ - pass - - @abstractmethod - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ - pass - - -# simple implementation -class ModelManagerService(ModelManagerServiceBase): - """Responsible for managing models on disk and in memory""" - - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - if config.model_conf_path and config.model_conf_path.exists(): - config_file = config.model_conf_path - else: - config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f"Config file={config_file}") - - device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" - logger.info(f"GPU device = {device} {device_name}") - - precision = config.precision - if precision == "auto": - precision = choose_precision(device) - dtype = torch.float32 if precision == "float32" else torch.float16 - - # this is transitional backward compatibility - # support for the deprecated `max_loaded_models` - # configuration value. If present, then the - # cache size is set to 2.5 GB times - # the number of max_loaded_models. 
Otherwise - # use new `ram_cache_size` config setting - max_cache_size = config.ram_cache_size - - logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") - - sequential_offload = config.sequential_guidance - - self.mgr = ModelManager( - config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger, - ) - logger.info("Model manager service initialized") - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> ModelInfo: - """ - Retrieve the indicated model. submodel can be used to get a - part (such as the vae) of a diffusers mode. - """ - - # we can emit model loading events if we are executing with access to the invocation context - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - model_info = self.mgr.get_model( - model_name, - base_model, - model_type, - submodel, - ) - - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - model_info=model_info, - ) - - return model_info - - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - """ - Given a model name, returns True if it is a valid - identifier. - """ - return self.mgr.model_exists( - model_name, - base_model, - model_type, - ) - - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - """ - return self.mgr.model_info(model_name, base_model, model_type) - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - return self.mgr.model_names() - - def list_models( - self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> list[dict]: - """ - Return a list of models. - """ - return self.mgr.list_models(base_model, model_type) - - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Return information about the model using the same format as list_models() - """ - return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) - - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"add/update model {model_name}") - return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. 
Will fail with a - ModelNotFoundException exception if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"update model {model_name}") - if not self.model_exists(model_name, base_model, model_type): - raise ModelNotFoundException(f"Unknown model {model_name}") - return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) - - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. - """ - self.logger.debug(f"delete model {model_name}") - self.mgr.del_model(model_name, base_model, model_type) - self.mgr.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - convert_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - self.logger.debug(f"convert model {model_name}") - return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) - - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ - self.mgr.cache.stats = cache_stats - - def commit(self, conf_file: Optional[Path] = None): - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. 
- """ - return self.mgr.commit(conf_file) - - def _emit_load_event( - self, - context: InvocationContext, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - model_info: Optional[ModelInfo] = None, - ): - if context.services.queue.is_canceled(context.graph_execution_state_id): - raise CanceledException() - - if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - model_info=model_info, - ) - else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - @property - def logger(self): - return self.mgr.logger - - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - """ - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) - - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_items=2, max_items=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. 
None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - merger = ModelMerger(self.mgr) - try: - result = merger.merge_diffusion_models_and_save( - model_names=model_names, - base_model=base_model, - merged_model_name=merged_model_name, - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=merge_dest_directory, - ) - except AssertionError as e: - raise ValueError(e) - return result - - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - search = FindModels([directory], self.logger) - return search.list_models() - - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - return self.mgr.sync_to_config() - - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - config = self.mgr.app_config - conf_path = config.legacy_conf_path - root_path = config.root_path - return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ): - """ - Rename the indicated model. Can provide a new name and/or a new base. - :param model_name: Current name of the model - :param base_model: Current base of the model - :param model_type: Model type (can't be changed) - :param new_name: New name for the model - :param new_base: New base for the model - """ - self.mgr.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=new_name, - new_base=new_base, - ) diff --git a/invokeai/app/services/model_record_service.py b/invokeai/app/services/model_record_service.py new file mode 100644 index 00000000000..a5779727b52 --- /dev/null +++ b/invokeai/app/services/model_record_service.py @@ -0,0 +1,130 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team + +from __future__ import annotations + +import sqlite3 +import threading +from abc import abstractmethod +from pathlib import Path +from typing import Optional, Union + +from invokeai.backend.model_manager import ( # noqa F401 + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SubModelType, +) +from invokeai.backend.model_manager.storage import ( # noqa F401 + ModelConfigStore, + ModelConfigStoreSQL, + ModelConfigStoreYAML, + UnknownModelException, +) +from invokeai.backend.util.logging import InvokeAILogger + +from .config import InvokeAIAppConfig + + +class ModelRecordServiceBase(ModelConfigStore): + """ + Responsible for managing model configuration records. + + This is an ABC that is simply a subclassing of the ModelConfigStore ABC + in the backend. + """ + + @classmethod + @abstractmethod + def from_db_file(cls, db_file: Path) -> ModelRecordServiceBase: + """ + Initialize a new object from a database file. + + If the path does not exist, a new sqlite3 db will be initialized. 
+ + :param db_file: Path to the database file + """ + pass + + @classmethod + def open( + cls, config: InvokeAIAppConfig, conn: Optional[sqlite3.Connection] = None, lock: Optional[threading.Lock] = None + ) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]: + """ + Choose either a ModelConfigStoreSQL or a ModelConfigStoreFile backend. + + Logic is as follows: + 1. if config.model_config_db contains a Path, then + a. if the path looks like a .db file, open a new sqlite3 connection and return a ModelRecordServiceSQL + b. if the path looks like a .yaml file, return a new ModelRecordServiceFile + c. otherwise bail + 2. if config.model_config_db is the literal 'auto', then use the passed sqlite3 connection and thread lock. + a. if either of these is missing, then we create our own connection to the invokeai.db file, which *should* + be a safe thing to do - sqlite3 will use file-level locking. + 3. if config.model_config_db is None, then fall back to config.conf_path, using a yaml file + """ + logger = InvokeAILogger.get_logger() + db = config.model_config_db + if db is None: + return ModelRecordServiceFile.from_db_file(config.model_conf_path) + if str(db) == "auto": + logger.info("Model config storage = main InvokeAI database") + return ( + ModelRecordServiceSQL.from_connection(conn, lock) + if (conn and lock) + else ModelRecordServiceSQL.from_db_file(config.db_path) + ) + assert isinstance(db, Path) + suffix = db.suffix + if suffix == ".yaml": + logger.info(f"Model config storage = {str(db)}") + return ModelRecordServiceFile.from_db_file(config.root_path / db) + elif suffix == ".db": + logger.info(f"Model config storage = {str(db)}") + return ModelRecordServiceSQL.from_db_file(config.root_path / db) + else: + raise ValueError( + f'Unrecognized model config record db file type {db} in "model_config_db" configuration variable.' + ) + + +class ModelRecordServiceSQL(ModelConfigStoreSQL): + """ + ModelRecordService that uses Sqlite for its backend. + Please see invokeai/backend/model_manager/storage/sql.py for + the implementation. + """ + + @classmethod + def from_connection(cls, conn: sqlite3.Connection, lock: threading.Lock) -> ModelRecordServiceSQL: + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. + + This is the same as the default __init__() constructor. + + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + return cls(conn, lock) + + @classmethod + def from_db_file(cls, db_file: Path) -> ModelRecordServiceSQL: # noqa D102 - docstring in ABC + Path(db_file).parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(db_file, check_same_thread=False) + lock = threading.Lock() + return cls(conn, lock) + + +class ModelRecordServiceFile(ModelConfigStoreYAML): + """ + ModelRecordService that uses a YAML file for its backend. + + Please see invokeai/backend/model_manager/storage/yaml.py for + the implementation. 
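
For orientation, here is a minimal sketch of how a caller might pick up a record store through `ModelRecordServiceBase.open()`, assuming a configured InvokeAI root so that `InvokeAIAppConfig.get_config()` resolves the user's settings; the loop at the end simply lists whatever `all_models()` returns.

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase

config = InvokeAIAppConfig.get_config()

# model_config_db == "auto"  -> sqlite store in the main invokeai.db
# model_config_db == *.yaml  -> ModelRecordServiceFile on that YAML file
# model_config_db == *.db    -> ModelRecordServiceSQL on a separate database
store = ModelRecordServiceBase.open(config)

for model_config in store.all_models():
    print(model_config.base_model, model_config.model_type, model_config.name)
```
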
+ """ + + @classmethod + def from_db_file(cls, db_file: Path) -> ModelRecordServiceFile: # noqa D102 - docstring in ABC + return cls(db_file) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index b4c894c52d7..37a1612ad33 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -97,8 +97,8 @@ def __process(self, stop_event: Event): # Invoke try: graph_id = graph_execution_state.id - model_manager = self.__invoker.services.model_manager - with statistics.collect_stats(invocation, graph_id, model_manager): + model_loader = self.__invoker.services.model_loader + with statistics.collect_stats(invocation, graph_id, model_loader): # use the internal invoke_internal(), which wraps the node's invoke() method, # which handles a few things: # - nodes that require a value, but get it only from a connection diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 6d4a857491b..d5bf21a5c58 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -4,7 +4,7 @@ from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.image import ProgressImage -from ...backend.model_management.models import BaseModelType +from ...backend.model_manager import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL from ..invocations.baseinvocation import InvocationContext diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index ae9a12edbe2..420b90d7b44 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,5 +1,15 @@ """ Initialization file for invokeai.backend """ -from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401 -from .model_management.models import SilenceWarnings # noqa: F401 +from .model_manager import ( # noqa F401 + BaseModelType, + DuplicateModelException, + InvalidModelException, + ModelConfigStore, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SilenceWarnings, + SubModelType, +) +from .util.devices import get_precision # noqa F401 diff --git a/invokeai/backend/install/check_root.py b/invokeai/backend/install/check_root.py index 6ee2aa34b7d..b35c2f4ea34 100644 --- a/invokeai/backend/install/check_root.py +++ b/invokeai/backend/install/check_root.py @@ -8,7 +8,7 @@ def check_invokeai_root(config: InvokeAIAppConfig): try: - assert config.model_conf_path.exists(), f"{config.model_conf_path} not found" + assert config.model_conf_path.parent.exists(), f"{config.model_conf_path.parent} not found" assert config.db_path.parent.exists(), f"{config.db_path.parent} not found" assert config.models_path.exists(), f"{config.models_path} not found" if not config.ignore_missing_core_models: diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py new file mode 100644 index 00000000000..7404c4ca2ea --- /dev/null +++ b/invokeai/backend/install/install_helper.py @@ -0,0 +1,196 @@ +""" +Utility (backend) functions used by model_install.py +""" +from pathlib import Path +from typing import Dict, List, Optional + +import omegaconf +from huggingface_hub import HfFolder +from pydantic import BaseModel, Field +from pydantic.dataclasses import dataclass +from tqdm import tqdm + +import invokeai.configs as configs +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_install_service 
import ModelInstallJob, ModelInstallService, ModelSourceMetadata +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource + +# name of the starter models file +INITIAL_MODELS = "INITIAL_MODELS.yaml" + + +class UnifiedModelInfo(BaseModel): + name: Optional[str] = None + base_model: Optional[BaseModelType] = None + model_type: Optional[ModelType] = None + source: Optional[str] = None + subfolder: Optional[str] = None + description: Optional[str] = None + recommended: bool = False + installed: bool = False + default: bool = False + requires: List[str] = Field(default_factory=list) + + +@dataclass +class InstallSelections: + install_models: List[UnifiedModelInfo] = Field(default_factory=list) + remove_models: List[str] = Field(default_factory=list) + + +class TqdmProgress(object): + _bars: Dict[int, tqdm] # the tqdm object + _last: Dict[int, int] # last bytes downloaded + + def __init__(self): + self._bars = dict() + self._last = dict() + + def job_update(self, job: ModelInstallJob): + if not isinstance(job, DownloadJobRemoteSource): + return + job_id = job.id + if job.status == "running" and job.total_bytes > 0: # job starts running before total bytes known + if job_id not in self._bars: + dest = Path(job.destination).name + self._bars[job_id] = tqdm( + desc=dest, + initial=0, + total=job.total_bytes, + unit="iB", + unit_scale=True, + ) + self._last[job_id] = 0 + self._bars[job_id].update(job.bytes - self._last[job_id]) + self._last[job_id] = job.bytes + + +class InstallHelper(object): + """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" + + all_models: Dict[str, UnifiedModelInfo] = dict() + _installer: ModelInstallService + _config: InvokeAIAppConfig + _installed_models: List[str] = [] + _starter_models: List[str] = [] + _default_model: Optional[str] = None + _initial_models: omegaconf.DictConfig + + def __init__(self, config: InvokeAIAppConfig): + self._config = config + self._installer = ModelInstallService(config=config, event_handlers=[TqdmProgress().job_update]) + self._initial_models = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS) + self._initialize_model_lists() + + @property + def installer(self) -> ModelInstallService: + return self._installer + + def _initialize_model_lists(self): + """ + Initialize our model slots. 
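
As a quick illustration of the `InstallHelper` workflow defined here, the snippet below queues the recommended starter models and lets `add_or_delete()` drive the underlying installer; it assumes a configured InvokeAI root and network access to the model sources.

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections

config = InvokeAIAppConfig.get_config()
helper = InstallHelper(config)

# Queue every recommended starter model; add_or_delete() resolves required
# companion models, starts the downloads, and waits for the installs to finish.
selections = InstallSelections(install_models=helper.recommended_models())
helper.add_or_delete(selections)
```
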
+ + Set up the following: + installed_models -- list of installed model keys + starter_models -- list of starter model keys from INITIAL_MODELS + all_models -- dict of key => UnifiedModelInfo + default_model -- key to default model + """ + # previously-installed models + for model in self._installer.store.all_models(): + info = UnifiedModelInfo.parse_obj(model.dict()) + info.installed = True + key = f"{model.base_model.value}/{model.model_type.value}/{model.name}" + self.all_models[key] = info + self._installed_models.append(key) + + for key in self._initial_models.keys(): + if key in self.all_models: + # we want to preserve the description + description = self.all_models[key].description or self._initial_models[key].get("description") + self.all_models[key].description = description + else: + base_model, model_type, model_name = key.split("/") + info = UnifiedModelInfo( + name=model_name, + model_type=model_type, + base_model=base_model, + source=self._initial_models[key].source, + description=self._initial_models[key].get("description"), + recommended=self._initial_models[key].get("recommended", False), + default=self._initial_models[key].get("default", False), + subfolder=self._initial_models[key].get("subfolder"), + requires=list(self._initial_models[key].get("requires", [])), + ) + self.all_models[key] = info + if not self.default_model: + self._default_model = key + elif self._initial_models[key].get("default", False): + self._default_model = key + self._starter_models.append(key) + + # previously-installed models + for model in self._installer.store.all_models(): + info = UnifiedModelInfo.parse_obj(model.dict()) + info.installed = True + key = f"{model.base_model.value}/{model.model_type.value}/{model.name}" + self.all_models[key] = info + self._installed_models.append(key) + + def recommended_models(self) -> List[UnifiedModelInfo]: + return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended] + + def installed_models(self) -> List[UnifiedModelInfo]: + return [self._to_model(x) for x in self._installed_models] + + def starter_models(self) -> List[UnifiedModelInfo]: + return [self._to_model(x) for x in self._starter_models] + + def default_model(self) -> UnifiedModelInfo: + return self._to_model(self._default_model) + + def _to_model(self, key: str) -> UnifiedModelInfo: + return self.all_models[key] + + def _add_required_models(self, model_list: List[UnifiedModelInfo]): + installed = {x.source for x in self.installed_models()} + reverse_source = {x.source: x for x in self.all_models.values()} + additional_models = [] + for model_info in model_list: + for requirement in model_info.requires: + if requirement not in installed: + additional_models.append(reverse_source.get(requirement)) + model_list.extend(additional_models) + + def add_or_delete(self, selections: InstallSelections): + installer = self._installer + self._add_required_models(selections.install_models) + for model in selections.install_models: + metadata = ModelSourceMetadata(description=model.description, name=model.name) + installer.install_model( + model.source, + subfolder=model.subfolder, + access_token=HfFolder.get_token(), + metadata=metadata, + ) + + for model in selections.remove_models: + parts = model.split("/") + if len(parts) == 1: + base_model, model_type, model_name = (None, None, model) + else: + base_model, model_type, model_name = parts + matches = installer.store.search_by_name( + base_model=base_model, model_type=model_type, model_name=model_name + ) + if len(matches) > 1: 
+ print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.") + elif not matches: + print(f"{model}: unknown model") + else: + for m in matches: + print(f"Deleting {m.model_type}:{m.name}") + installer.conditionally_delete(m.key) + + installer.wait_for_installs() diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 5afbdfb5a3d..3029b27a8b7 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -22,7 +22,6 @@ from urllib import request import npyscreen -import omegaconf import psutil import torch import transformers @@ -38,21 +37,25 @@ import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.install.legacy_arg_parsing import legacy_parser -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained -from invokeai.backend.model_management.model_probe import BaseModelType, ModelType +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.model_manager.storage import ConfigFileVersionMismatchException, migrate_models_store +from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.model_install import addModelsForm, process_and_execute +from invokeai.frontend.install.model_install import addModelsForm # TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, CenteredButtonPress, + CheckboxWithChanged, CyclingForm, FileBox, MultiSelectColumns, SingleSelectColumnsSimple, + SingleSelectWithChanged, WindowTooSmallException, set_min_terminal_size, ) @@ -82,7 +85,6 @@ def get_literal_fields(field) -> list[Any]: HAS_CUDA = torch.cuda.is_available() _, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0) - MAX_VRAM /= GB MAX_RAM = psutil.virtual_memory().total / GB @@ -96,6 +98,8 @@ def get_literal_fields(field) -> list[Any]: class DummyWidgetValue(Enum): + """Dummy widget values.""" + zero = 0 true = True false = False @@ -179,6 +183,22 @@ def __call__(self, block_num, block_size, total_size): self.pbar.update(block_size) +# --------------------------------------------- +def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): + filter = lambda x: "fp16 is not a valid" not in x.getMessage() + logger.addFilter(filter) + try: + model = model_class.from_pretrained( + model_name, + resume_download=True, + **kwargs, + ) + model.save_pretrained(destination, safe_serialization=True) + finally: + logger.removeFilter(filter) + return destination + + # --------------------------------------------- def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"): try: @@ -455,6 +475,25 @@ def create(self): max_width=110, scroll_exit=True, ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Model disk conversion cache size (GB). 
This is used to cache safetensors files that need to be converted to diffusers..", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 1 + self.disk = self.add_widget_intelligent( + npyscreen.Slider, + value=clip(old_opts.disk, range=(0, 100), step=0.5), + out_of=100, + lowest=0.0, + step=0.5, + relx=8, + scroll_exit=True, + ) + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).", @@ -495,6 +534,45 @@ def create(self): ) else: self.vram = DummyWidgetValue.zero + + self.nextrely += 1 + self.add_widget_intelligent( + npyscreen.FixedText, + value="Location of the database used to store model path and configuration information:", + editable=False, + color="CONTROL", + ) + self.nextrely += 1 + if first_time: + old_opts.model_config_db = "auto" + self.model_conf_auto = self.add_widget_intelligent( + CheckboxWithChanged, + value=str(old_opts.model_config_db) == "auto", + name="Main database", + relx=2, + max_width=25, + scroll_exit=True, + ) + self.nextrely -= 2 + config_db = str(old_opts.model_config_db or old_opts.conf_path) + self.model_conf_override = self.add_widget_intelligent( + FileBox, + value=str(old_opts.root_path / config_db) + if config_db != "auto" + else str(old_opts.root_path / old_opts.conf_path), + name="Specify models config database manually", + select_dir=False, + must_exist=False, + use_two_lines=False, + labelColor="GOOD", + # begin_entry_at=40, + relx=30, + max_height=3, + max_width=100, + scroll_exit=True, + hidden=str(old_opts.model_config_db) == "auto", + ) + self.model_conf_auto.on_changed = self.show_hide_model_conf_override self.nextrely += 1 self.outdir = self.add_widget_intelligent( FileBox, @@ -506,19 +584,21 @@ def create(self): labelColor="GOOD", begin_entry_at=40, max_height=3, + max_width=127, scroll_exit=True, ) self.autoimport_dirs = {} self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent( FileBox, - name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models", - value=str(config.root_path / config.autoimport_dir), + name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models", + value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "", select_dir=True, must_exist=False, use_two_lines=False, labelColor="GOOD", begin_entry_at=32, max_height=3, + max_width=127, scroll_exit=True, ) self.nextrely += 1 @@ -555,6 +635,10 @@ def show_hide_slice_sizes(self, value): self.attention_slice_label.hidden = not show self.attention_slice_size.hidden = not show + def show_hide_model_conf_override(self, value): + self.model_conf_override.hidden = value + self.model_conf_override.display() + def on_ok(self): options = self.marshall_arguments() if self.validate_field_values(options): @@ -590,17 +674,21 @@ def marshall_arguments(self): for attr in [ "ram", "vram", + "disk", "outdir", ]: if hasattr(self, attr): setattr(new_opts, attr, getattr(self, attr).value) for attr in self.autoimport_dirs: + if not self.autoimport_dirs[attr].value: + continue directory = Path(self.autoimport_dirs[attr].value) if directory.is_relative_to(config.root_path): directory = directory.relative_to(config.root_path) setattr(new_opts, attr, directory) + new_opts.model_config_db = "auto" if self.model_conf_auto.value else self.model_conf_override.value new_opts.hf_token = self.hf_token.value new_opts.license_acceptance 
= self.license_acceptance.value new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] @@ -615,13 +703,14 @@ def marshall_arguments(self): class EditOptApplication(npyscreen.NPSAppManaged): - def __init__(self, program_opts: Namespace, invokeai_opts: Namespace): + def __init__(self, program_opts: Namespace, invokeai_opts: Namespace, install_helper: InstallHelper): super().__init__() self.program_opts = program_opts self.invokeai_opts = invokeai_opts self.user_cancelled = False self.autoload_pending = True - self.install_selections = default_user_selections(program_opts) + self.install_helper = install_helper + self.install_selections = default_user_selections(program_opts, install_helper) def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -644,12 +733,6 @@ def new_opts(self): return self.options.marshall_arguments() -def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace: - editApp = EditOptApplication(program_opts, invokeai_opts) - editApp.run() - return editApp.new_opts() - - def default_ramcache() -> float: """Run a heuristic for the default RAM cache based on installed RAM.""" @@ -666,21 +749,12 @@ def default_startup_options(init_file: Path) -> Namespace: return opts -def default_user_selections(program_opts: Namespace) -> InstallSelections: - try: - installer = ModelInstall(config) - except omegaconf.errors.ConfigKeyError: - logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing") - initialize_rootdir(config.root_path, True) - installer = ModelInstall(config) - - models = installer.all_models() +def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections: + default_models = ( + [install_helper.default_model()] if program_opts.default_only else install_helper.recommended_models() + ) return InstallSelections( - install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] - if program_opts.default_only - else [models[x].path or models[x].repo_id for x in installer.recommended_models()] - if program_opts.yes_to_all - else list(), + install_models=default_models if program_opts.yes_to_all else list(), ) @@ -730,7 +804,7 @@ def maybe_create_models_yaml(root: Path): # ------------------------------------- -def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace): +def run_console_ui(program_opts: Namespace, initfile: Path, install_helper: InstallHelper) -> (Namespace, Namespace): invokeai_opts = default_startup_options(initfile) invokeai_opts.root = program_opts.root @@ -739,13 +813,7 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - # the install-models application spawns a subprocess to install - # models, and will crash unless this is set before running. 
- import torch - - torch.multiprocessing.set_start_method("spawn") - - editApp = EditOptApplication(program_opts, invokeai_opts) + editApp = EditOptApplication(program_opts, invokeai_opts, install_helper) editApp.run() if editApp.user_cancelled: return (None, None) @@ -904,6 +972,7 @@ def main(): if opt.full_precision: invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) + config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) logger = InvokeAILogger().get_logger(config=config) errors = set() @@ -917,14 +986,22 @@ def main(): # run this unconditionally in case new directories need to be added initialize_rootdir(config.root_path, opt.yes_to_all) - models_to_download = default_user_selections(opt) + # this will initialize the models.yaml file if not present + try: + install_helper = InstallHelper(config) + except ConfigFileVersionMismatchException: + config.model_config_db = migrate_models_store(config) + install_helper = InstallHelper(config) + + models_to_download = default_user_selections(opt, install_helper) new_init_file = config.root_path / "invokeai.yaml" if opt.yes_to_all: write_default_options(opt, new_init_file) init_options = Namespace(precision="float32" if opt.full_precision else "float16") + else: - init_options, models_to_download = run_console_ui(opt, new_init_file) + init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper) if init_options: write_opts(init_options, new_init_file) else: @@ -939,10 +1016,12 @@ def main(): if opt.skip_sd_weights: logger.warning("Skipping diffusion weights download per user request") + elif models_to_download: - process_and_execute(opt, models_to_download) + install_helper.add_or_delete(models_to_download) postscript(errors=errors) + if not opt.yes_to_all: input("Press any key to continue...") except WindowTooSmallException as e: diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index ea5bee8058a..7d457559bdb 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -3,13 +3,15 @@ InvokeAI 2.3 installation to 3.0.0. """ +#### NOTE: THIS SCRIPT NO LONGER WORKS WITH REFACTORED MODEL MANAGER, AND WILL NOT BE UPDATED. 
+ import argparse import os import shutil import warnings from dataclasses import dataclass from pathlib import Path -from typing import Union +from typing import Optional, Union import diffusers import transformers @@ -21,8 +23,9 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager -from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType +from invokeai.app.services.model_install_service import ModelInstallService +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.backend.model_manager import BaseModelType, ModelProbe, ModelProbeInfo, ModelType warnings.filterwarnings("ignore") transformers.logging.set_verbosity_error() @@ -43,19 +46,14 @@ def __init__( self, from_root: Path, to_models: Path, - model_manager: ModelManager, + installer: ModelInstallService, src_paths: ModelPaths, ): self.root_directory = from_root self.dest_models = to_models - self.mgr = model_manager + self.installer = installer self.src_paths = src_paths - @classmethod - def initialize_yaml(cls, yaml_file: Path): - with open(yaml_file, "w") as file: - file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - def create_directory_structure(self): """ Create the basic directory structure for the models folder. @@ -107,44 +105,10 @@ def migrate_models(self, src_dir: Path): Recursively walk through src directory, probe anything that looks like a model, and copy the model into the appropriate location within the destination models directory. + + This is now trivially easy using the installer service. """ - directories_scanned = set() - for root, dirs, files in os.walk(src_dir, followlinks=True): - for d in dirs: - try: - model = Path(root, d) - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / model.name - self.copy_dir(model, dest) - directories_scanned.add(model) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - for f in files: - # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin - # let them be copied as part of a tree copy operation - try: - if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}: - continue - model = Path(root, f) - if model.parent in directories_scanned: - continue - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / f - self.copy_file(model, dest) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) + self.installer.scan_directory(src_dir) def migrate_support_models(self): """ @@ -260,23 +224,21 @@ def _save_pretrained(self, model, dest: Path, overwrite: bool = False): model.save_pretrained(download_path, safe_serialization=True) download_path.replace(dest) - def _download_vae(self, repo_id: str, subfolder: str = None) -> Path: - vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder) - info = ModelProbe().heuristic_probe(vae) - _, model_name = repo_id.split("/") - dest = self._model_probe_to_path(info) / self.unique_name(model_name, info) - vae.save_pretrained(dest, safe_serialization=True) - return dest + def _download_vae(self, repo_id: str, subfolder: str = None) -> Optional[Path]: + self.installer.install(repo_id) # bug! 
We don't support subfolder yet. + ids = self.installer.wait_for_installs() + if key := ids.get(repo_id): + return self.installer.store.get_model(key).path + else: + return None - def _vae_path(self, vae: Union[str, dict]) -> Path: - """ - Convert 2.3 VAE stanza to a straight path. - """ - vae_path = None + def _vae_path(self, vae: Union[str, dict]) -> Optional[Path]: + """Convert 2.3 VAE stanza to a straight path.""" + vae_path: Optional[Path] = None # First get a path if isinstance(vae, str): - vae_path = vae + vae_path = Path(vae) elif isinstance(vae, DictConfig): if p := vae.get("path"): @@ -284,28 +246,21 @@ def _vae_path(self, vae: Union[str, dict]) -> Path: elif repo_id := vae.get("repo_id"): if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded vae_path = "models/core/convert/sd-vae-ft-mse" - return vae_path + return Path(vae_path) else: vae_path = self._download_vae(repo_id, vae.get("subfolder")) - assert vae_path is not None, "Couldn't find VAE for this model" + if vae_path is None: + return None # if the VAE is in the old models directory, then we must move it into the new # one. VAEs outside of this directory can stay where they are. - vae_path = Path(vae_path) if vae_path.is_relative_to(self.src_paths.models): - info = ModelProbe().heuristic_probe(vae_path) - dest = self._model_probe_to_path(info) / vae_path.name - if not dest.exists(): - if vae_path.is_dir(): - self.copy_dir(vae_path, dest) - else: - self.copy_file(vae_path, dest) - vae_path = dest - - if vae_path.is_relative_to(self.dest_models): - rel_path = vae_path.relative_to(self.dest_models) - return Path("models", rel_path) + key = self.installer.install_path(vae_path) # this will move the model + return self.installer.store.get_model(key).path + elif vae_path.is_relative_to(self.dest_models): + key = self.installer.register_path(vae_path) # this will keep the model in place + return self.installer.store.get_model(key).path else: return vae_path @@ -501,44 +456,27 @@ def get_legacy_embeddings(root: Path) -> ModelPaths: return _parse_legacy_yamlfile(root, path) -def do_migrate(src_directory: Path, dest_directory: Path): +def do_migrate(config: InvokeAIAppConfig, src_directory: Path, dest_directory: Path): """ Migrate models from src to dest InvokeAI root directories """ - config_file = dest_directory / "configs" / "models.yaml.3" dest_models = dest_directory / "models.3" + mm_store = ModelRecordServiceBase.open(config) + mm_install = ModelInstallService(config=config, store=mm_store) version_3 = (dest_directory / "models" / "core").exists() - - # Here we create the destination models.yaml file. - # If we are writing into a version 3 directory and the - # file already exists, then we write into a copy of it to - # avoid deleting its previous customizations. Otherwise we - # create a new empty one. 
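
The new `do_migrate()` wires a record store into an install service; a stripped-down sketch of that wiring, with a hypothetical scan path, is shown below (`scan_directory()` is assumed to probe and register models the same way `migrate_models()` uses it).

```
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstallService
from invokeai.app.services.model_record_service import ModelRecordServiceBase

config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
installer = ModelInstallService(config=config, store=store)

# Probe everything under the (hypothetical) directory, register it in the
# store, and block until the queued downloads/installs complete.
installer.scan_directory(Path("/path/to/legacy/models"))
installer.wait_for_installs()
```
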
- if version_3: # write into the dest directory - try: - shutil.copy(dest_directory / "configs" / "models.yaml", config_file) - except Exception: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory - (dest_directory / "models").replace(dest_models) - else: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) + if not version_3: + src_directory = (dest_directory / "models").replace(src_directory / "models.orig") + print(f"Original models directory moved to {dest_directory}/models.orig") paths = get_legacy_embeddings(src_directory) - migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths) + migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, installer=mm_install, src_paths=paths) migrator.migrate() print("Migration successful.") - if not version_3: - (dest_directory / "models").replace(src_directory / "models.orig") - print(f"Original models directory moved to {dest_directory}/models.orig") - (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig") print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig") - config_file.replace(config_file.with_suffix("")) dest_models.replace(dest_models.with_suffix("")) @@ -588,7 +526,7 @@ def main(): initialize_rootdir(dest_root, True) - do_migrate(src_root, dest_root) + do_migrate(config, src_root, dest_root) if __name__ == "__main__": diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py deleted file mode 100644 index 1481300c77f..00000000000 --- a/invokeai/backend/install/model_install_backend.py +++ /dev/null @@ -1,609 +0,0 @@ -""" -Utility (backend) functions used by model_install.py -""" -import os -import re -import shutil -import warnings -from dataclasses import dataclass, field -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Callable, Dict, List, Optional, Set, Union - -import requests -import torch -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging -from huggingface_hub import HfApi, HfFolder, hf_hub_url -from omegaconf import OmegaConf -from tqdm import tqdm - -import invokeai.configs as configs -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType -from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType -from invokeai.backend.util import download_with_resume -from invokeai.backend.util.devices import choose_torch_device, torch_dtype - -from ..util.logging import InvokeAILogger - -warnings.filterwarnings("ignore") - -# --------------------------globals----------------------- -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger(name="InvokeAI") - -# the initial "configs" dir is now bundled in the `invokeai.configs` package -Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" - -Config_preamble = """ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. 
-""" - -LEGACY_CONFIGS = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -@dataclass -class InstallSelections: - install_models: List[str] = field(default_factory=list) - remove_models: List[str] = field(default_factory=list) - - -@dataclass -class ModelLoadInfo: - name: str - model_type: ModelType - base_type: BaseModelType - path: Optional[Path] = None - repo_id: Optional[str] = None - subfolder: Optional[str] = None - description: str = "" - installed: bool = False - recommended: bool = False - default: bool = False - requires: Optional[List[str]] = field(default_factory=list) - - -class ModelInstall(object): - def __init__( - self, - config: InvokeAIAppConfig, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - model_manager: Optional[ModelManager] = None, - access_token: Optional[str] = None, - ): - self.config = config - self.mgr = model_manager or ModelManager(config.model_conf_path) - self.datasets = OmegaConf.load(Dataset_path) - self.prediction_helper = prediction_type_helper - self.access_token = access_token or HfFolder.get_token() - self.reverse_paths = self._reverse_paths(self.datasets) - - def all_models(self) -> Dict[str, ModelLoadInfo]: - """ - Return dict of model_key=>ModelLoadInfo objects. - This method consolidates and simplifies the entries in both - models.yaml and INITIAL_MODELS.yaml so that they can - be treated uniformly. It also sorts the models alphabetically - by their name, to improve the display somewhat. 
- """ - model_dict = dict() - - # first populate with the entries in INITIAL_MODELS.yaml - for key, value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - value["name"] = name - value["base_type"] = base - value["model_type"] = model_type - model_info = ModelLoadInfo(**value) - if model_info.subfolder and model_info.repo_id: - model_info.repo_id += f":{model_info.subfolder}" - model_dict[key] = model_info - - # supplement with entries in models.yaml - installed_models = [x for x in self.mgr.list_models()] - - for md in installed_models: - base = md["base_model"] - model_type = md["model_type"] - name = md["model_name"] - key = ModelManager.create_key(name, base, model_type) - if key in model_dict: - model_dict[key].installed = True - else: - model_dict[key] = ModelLoadInfo( - name=name, - base_type=base, - model_type=model_type, - path=value.get("path"), - installed=True, - ) - return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} - - def _is_autoloaded(self, model_info: dict) -> bool: - path = model_info.get("path") - if not path: - return False - for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]: - if autodir_path := getattr(self.config, autodir): - autodir_path = self.config.root_path / autodir_path - if Path(path).is_relative_to(autodir_path): - return True - return False - - def list_models(self, model_type): - installed = self.mgr.list_models(model_type=model_type) - print() - print(f"Installed models of type `{model_type}`:") - print(f"{'Model Key':50} Model Path") - for i in installed: - print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}") - print() - - # logic here a little reversed to maintain backward compatibility - def starter_models(self, all_models: bool = False) -> Set[str]: - models = set() - for key, value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - if all_models or model_type in [ModelType.Main, ModelType.Vae]: - models.add(key) - return models - - def recommended_models(self) -> Set[str]: - starters = self.starter_models(all_models=True) - return set([x for x in starters if self.datasets[x].get("recommended", False)]) - - def default_model(self) -> str: - starters = self.starter_models() - defaults = [x for x in starters if self.datasets[x].get("default", False)] - return defaults[0] - - def install(self, selections: InstallSelections): - verbosity = dlogging.get_verbosity() # quench NSFW nags - dlogging.set_verbosity_error() - - job = 1 - jobs = len(selections.remove_models) + len(selections.install_models) - - # remove requested models - for key in selections.remove_models: - name, base, mtype = self.mgr.parse_key(key) - logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]") - try: - self.mgr.del_model(name, base, mtype) - except FileNotFoundError as e: - logger.warning(e) - job += 1 - - # add requested models - self._remove_installed(selections.install_models) - self._add_required_models(selections.install_models) - for path in selections.install_models: - logger.info(f"Installing {path} [{job}/{jobs}]") - try: - self.heuristic_import(path) - except (ValueError, KeyError) as e: - logger.error(str(e)) - job += 1 - - dlogging.set_verbosity(verbosity) - self.mgr.commit() - - def heuristic_import( - self, - model_path_id_or_url: Union[str, Path], - models_installed: Set[Path] = None, - ) -> Dict[str, AddModelResult]: - """ - :param model_path_id_or_url: A Path to a local model to 
import, or a string representing its repo_id or URL - :param models_installed: Set of installed models, used for recursive invocation - Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. - """ - - if not models_installed: - models_installed = dict() - - # A little hack to allow nested routines to retrieve info on the requested ID - self.current_id = model_path_id_or_url - path = Path(model_path_id_or_url) - # checkpoint file, or similar - if path.is_file(): - models_installed.update({str(path): self._install_path(path)}) - - # folders style or similar - elif path.is_dir() and any( - [ - (path / x).exists() - for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"} - ] - ): - models_installed.update({str(model_path_id_or_url): self._install_path(path)}) - - # recursive scan - elif path.is_dir(): - for child in path.iterdir(): - self.heuristic_import(child, models_installed=models_installed) - - # huggingface repo - elif len(str(model_path_id_or_url).split("/")) == 2: - models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) - - # a URL - elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): - models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) - - else: - raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping") - - return models_installed - - def _remove_installed(self, model_list: List[str]): - all_models = self.all_models() - for path in model_list: - key = self.reverse_paths.get(path) - if key and all_models[key].installed: - logger.warning(f"{path} already installed. Skipping.") - model_list.remove(path) - - def _add_required_models(self, model_list: List[str]): - additional_models = [] - all_models = self.all_models() - for path in model_list: - if not (key := self.reverse_paths.get(path)): - continue - for requirement in all_models[key].requires: - requirement_key = self.reverse_paths.get(requirement) - if not all_models[requirement_key].installed: - additional_models.append(requirement) - model_list.extend(additional_models) - - # install a model from a local path. The optional info parameter is there to prevent - # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: - info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) - if not info: - logger.warning(f"Unable to parse format of {path}") - return None - model_name = path.stem if path.is_file() else path.name - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path, info) - return self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - model_attributes=attributes, - ) - - def _install_url(self, url: str) -> AddModelResult: - with TemporaryDirectory(dir=self.config.models_path) as staging: - location = download_with_resume(url, Path(staging)) - if not location: - logger.error(f"Unable to download {url}. 
Skipping.") - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name - dest.parent.mkdir(parents=True, exist_ok=True) - models_path = shutil.move(location, dest) - - # staged version will be garbage-collected at this time - return self._install_path(Path(models_path), info) - - def _install_repo(self, repo_id: str) -> AddModelResult: - # hack to recover models stored in subfolders -- - # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster - subfolder = None - if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id): - repo_id = match.group(1) - subfolder = match.group(2) - - hinfo = HfApi().model_info(repo_id) - - # we try to figure out how to download this most economically - # list all the files in the repo - files = [x.rfilename for x in hinfo.siblings] - if subfolder: - files = [x for x in files if x.startswith(f"{subfolder}/")] - prefix = f"{subfolder}/" if subfolder else "" - - location = None - - with TemporaryDirectory(dir=self.config.models_path) as staging: - staging = Path(staging) - if f"{prefix}model_index.json" in files: - location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline - elif f"{prefix}unet/model.onnx" in files: - location = self._download_hf_model(repo_id, files, staging) - else: - for suffix in ["safetensors", "bin"]: - if f"{prefix}pytorch_lora_weights.{suffix}" in files: - location = self._download_hf_model( - repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder - ) # LoRA - break - elif ( - self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files - ): # vae, controlnet or some other standalone - files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}diffusion_pytorch_model.{suffix}" in files: - files = ["config.json", f"diffusion_pytorch_model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}learned_embeds.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder - ) - break - elif ( - f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files - ): # IP-Adapter - files = ["image_encoder.txt", f"ip_adapter.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files: - # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted - # by InvokeAI for use with IP-Adapters. - files = ["config.json", f"model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - if not location: - logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") - return {} - - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - if not info: - logger.warning(f"Could not probe {location}. 
Skipping install.") - return {} - dest = ( - self.config.models_path - / info.base_type.value - / info.model_type.value - / self._get_model_name(repo_id, location) - ) - if dest.exists(): - shutil.rmtree(dest) - shutil.copytree(location, dest) - return self._install_path(dest, info) - - def _get_model_name(self, path_name: str, location: Path) -> str: - """ - Calculate a name for the model - primitive implementation. - """ - if key := self.reverse_paths.get(path_name): - (name, base, mtype) = ModelManager.parse_key(key) - return name - elif location.is_dir(): - return location.name - else: - return location.stem - - def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: - model_name = path.name if path.is_dir() else path.stem - description = f"{info.base_type.value} {info.model_type.value} model {model_name}" - if key := self.reverse_paths.get(self.current_id): - if key in self.datasets: - description = self.datasets[key].get("description") or description - - rel_path = self.relative_to_root(path, self.config.models_path) - - attributes = dict( - path=str(rel_path), - description=str(description), - model_format=info.format, - ) - legacy_conf = None - if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: - attributes.update( - dict( - variant=info.variant_type, - ) - ) - if info.format == "checkpoint": - try: - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - legacy_conf = Path( - self.config.legacy_conf_dir, - LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], - ) - else: - legacy_conf = Path( - self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type] - ) - except KeyError: - legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess - - if info.model_type == ModelType.ControlNet and info.format == "checkpoint": - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - - if legacy_conf: - attributes.update(dict(config=str(legacy_conf))) - return attributes - - def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: - root = root or self.config.root_path - if path.is_relative_to(root): - return path.relative_to(root) - else: - return path - - def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path: - """ - Retrieve a StableDiffusion model from cache or remote and then - does a save_pretrained() to the indicated staging area. - """ - _, name = repo_id.split("/") - precision = torch_dtype(choose_torch_device()) - variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] - - model = None - for variant in variants: - try: - model = DiffusionPipeline.from_pretrained( - repo_id, - variant=variant, - torch_dtype=precision, - safety_checker=None, - subfolder=subfolder, - ) - except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors - if "fp16" not in str(e): - print(e) - - if model: - break - - if not model: - logger.error(f"Diffusers model {repo_id} could not be downloaded. 
Skipping.") - return None - model.save_pretrained(staging / name, safe_serialization=True) - return staging / name - - def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: - _, name = repo_id.split("/") - location = staging / name - paths = list() - for filename in files: - filePath = Path(filename) - p = hf_download_with_resume( - repo_id, - model_dir=location / filePath.parent, - model_name=filePath.name, - access_token=self.access_token, - subfolder=filePath.parent / subfolder if subfolder else filePath.parent, - ) - if p: - paths.append(p) - else: - logger.warning(f"Could not download {filename} from {repo_id}.") - - return location if len(paths) > 0 else None - - @classmethod - def _reverse_paths(cls, datasets) -> dict: - """ - Reverse mapping from repo_id/path to destination name. - """ - return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()} - - -# ------------------------------------- -def yes_or_no(prompt: str, default_yes=True): - default = "y" if default_yes else "n" - response = input(f"{prompt} [{default}] ") or default - if default_yes: - return response[0] not in ("n", "N") - else: - return response[0] in ("y", "Y") - - -# --------------------------------------------- -def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): - logger = InvokeAILogger.get_logger("InvokeAI") - logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) - - model = model_class.from_pretrained( - model_name, - resume_download=True, - **kwargs, - ) - model.save_pretrained(destination, safe_serialization=True) - return destination - - -# --------------------------------------------- -def hf_download_with_resume( - repo_id: str, - model_dir: str, - model_name: str, - model_dest: Path = None, - access_token: str = None, - subfolder: str = None, -) -> Path: - model_dest = model_dest or Path(os.path.join(model_dir, model_name)) - os.makedirs(model_dir, exist_ok=True) - - url = hf_hub_url(repo_id, model_name, subfolder=subfolder) - - header = {"Authorization": f"Bearer {access_token}"} if access_token else {} - open_mode = "wb" - exist_size = 0 - - if os.path.exists(model_dest): - exist_size = os.path.getsize(model_dest) - header["Range"] = f"bytes={exist_size}-" - open_mode = "ab" - - resp = requests.get(url, headers=header, stream=True) - total = int(resp.headers.get("content-length", 0)) - - if resp.status_code == 416: # "range not satisfiable", which means nothing to return - logger.info(f"{model_name}: complete file found. Skipping.") - return model_dest - elif resp.status_code == 404: - logger.warning("File not found") - return None - elif resp.status_code != 200: - logger.warning(f"{model_name}: {resp.reason}") - elif exist_size > 0: - logger.info(f"{model_name}: partial file found. 
Resuming...") - else: - logger.info(f"{model_name}: Downloading...") - - try: - with ( - open(model_dest, open_mode) as file, - tqdm( - desc=model_name, - initial=exist_size, - total=total + exist_size, - unit="iB", - unit_scale=True, - unit_divisor=1000, - ) as bar, - ): - for data in resp.iter_content(chunk_size=1024): - size = file.write(data) - bar.update(size) - except Exception as e: - logger.error(f"An error occurred while downloading {model_name}: {str(e)}") - return None - return model_dest diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 826112156df..1cb6b85c665 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -8,7 +8,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from invokeai.backend.model_management.models.base import calc_model_size_by_data +from invokeai.backend.model_manager.models.base import calc_model_size_by_data from .resampler import Resampler diff --git a/invokeai/backend/model_management/README b/invokeai/backend/model_management/README new file mode 100644 index 00000000000..c7388df72e6 --- /dev/null +++ b/invokeai/backend/model_management/README @@ -0,0 +1 @@ +The contents of this directory are deprecated. model_manager.py is here only for reference. diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py deleted file mode 100644 index 7e6b37c832c..00000000000 --- a/invokeai/backend/model_management/model_search.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2023, Lincoln D. Stein and the InvokeAI Team -""" -Abstract base class for recursive directory search for models. -""" - -import os -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Set, types - -import invokeai.backend.util.logging as logger - - -class ModelSearch(ABC): - def __init__(self, directories: List[Path], logger: types.ModuleType = logger): - """ - Initialize a recursive model directory search. - :param directories: List of directory Paths to recurse through - :param logger: Logger to use - """ - self.directories = directories - self.logger = logger - self._items_scanned = 0 - self._models_found = 0 - self._scanned_dirs = set() - self._scanned_paths = set() - self._pruned_paths = set() - - @abstractmethod - def on_search_started(self): - """ - Called before the scan starts. - """ - pass - - @abstractmethod - def on_model_found(self, model: Path): - """ - Process a found model. Raise an exception if something goes wrong. - :param model: Model to process - could be a directory or checkpoint. - """ - pass - - @abstractmethod - def on_search_completed(self): - """ - Perform some activity when the scan is completed. 
May use instance - variables, items_scanned and models_found - """ - pass - - def search(self): - self.on_search_started() - for dir in self.directories: - self.walk_directory(dir) - self.on_search_completed() - - def walk_directory(self, path: Path): - for root, dirs, files in os.walk(path, followlinks=True): - if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) - if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): - continue - - self._items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in self._scanned_paths or path.parent in self._scanned_dirs: - self._scanned_dirs.add(path) - continue - if any( - [ - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - } - ] - ): - try: - self.on_model_found(path) - self._models_found += 1 - self._scanned_dirs.add(path) - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - for f in files: - path = Path(root) / f - if path.parent in self._scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.on_model_found(path) - self._models_found += 1 - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - -class FindModels(ModelSearch): - def on_search_started(self): - self.models_found: Set[Path] = set() - - def on_model_found(self, model: Path): - self.models_found.add(model) - - def on_search_completed(self): - pass - - def list_models(self) -> List[Path]: - self.search() - return list(self.models_found) diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py deleted file mode 100644 index 6d70107c934..00000000000 --- a/invokeai/backend/model_management/util.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2023 The InvokeAI Development Team -"""Utilities used by the Model Manager""" - - -def lora_token_vector_length(checkpoint: dict) -> int: - """ - Given a checkpoint in memory, return the lora token vector length - - :param checkpoint: The checkpoint - """ - - def _get_shape_1(key, tensor, checkpoint): - lora_token_vector_length = None - - if "." 
not in key: - return lora_token_vector_length # wrong key format - model_key, lora_key = key.split(".", 1) - - # check lora/locon - if lora_key == "lora_down.weight": - lora_token_vector_length = tensor.shape[1] - - # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) - elif lora_key in ["hada_w1_b", "hada_w2_b"]: - lora_token_vector_length = tensor.shape[1] - - # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) - elif "lokr_" in lora_key: - if model_key + ".lokr_w1" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1"] - elif model_key + "lokr_w1_b" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] - else: - return lora_token_vector_length # unknown format - - if model_key + ".lokr_w2" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2"] - elif model_key + "lokr_w2_b" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] - else: - return lora_token_vector_length # unknown format - - lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] - - elif lora_key == "diff": - lora_token_vector_length = tensor.shape[1] - - # ia3 can be detected only by shape[0] in text encoder - elif lora_key == "weight" and "lora_unet_" not in model_key: - lora_token_vector_length = tensor.shape[0] - - return lora_token_vector_length - - lora_token_vector_length = None - lora_te1_length = None - lora_te2_length = None - for key, tensor in checkpoint.items(): - if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): - lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) - elif key.startswith("lora_te") and "_self_attn_" in key: - tmp_length = _get_shape_1(key, tensor, checkpoint) - if key.startswith("lora_te_"): - lora_token_vector_length = tmp_length - elif key.startswith("lora_te1_"): - lora_te1_length = tmp_length - elif key.startswith("lora_te2_"): - lora_te2_length = tmp_length - - if lora_te1_length is not None and lora_te2_length is not None: - lora_token_vector_length = lora_te1_length + lora_te2_length - - if lora_token_vector_length is not None: - break - - return lora_token_vector_length diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py new file mode 100644 index 00000000000..c8247099ab0 --- /dev/null +++ b/invokeai/backend/model_manager/__init__.py @@ -0,0 +1,27 @@ +"""Initialization file for invokeai.backend.model_manager.config.""" +from .config import ( # noqa F401 + BaseModelType, + InvalidModelConfigException, + ModelConfigBase, + ModelConfigFactory, + ModelFormat, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SilenceWarnings, + SubModelType, +) + +# from .install import ModelInstall, ModelInstallJob # noqa F401 +# from .loader import ModelInfo, ModelLoad # noqa F401 +# from .lora import ModelPatcher, ONNXModelPatcher # noqa F401 +from .models import OPENAPI_MODEL_CONFIGS, InvalidModelException, read_checkpoint_meta # noqa F401 +from .probe import ModelProbe, ModelProbeInfo # noqa F401 +from .search import ModelSearch # noqa F401 +from .storage import ( # noqa F401 + DuplicateModelException, + ModelConfigStore, + ModelConfigStoreSQL, + ModelConfigStoreYAML, + UnknownModelException, +) diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_manager/cache.py similarity index 90% rename from invokeai/backend/model_management/model_cache.py rename to invokeai/backend/model_manager/cache.py index 8cb6b55caf7..248723e046a 100644 --- 
a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_manager/cache.py @@ -1,5 +1,6 @@ """ Manage a RAM cache of diffusion/transformer models for fast switching. + They are moved between GPU VRAM and CPU RAM as necessary. If the cache grows larger than a preset maximum, then the least recently used model will be cleared and (re)loaded from disk when next needed. @@ -25,13 +26,14 @@ from contextlib import suppress from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Optional, Type, Union, types +from typing import Any, Dict, List, Optional, Type, Union import torch -import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.util import InvokeAILogger, Logger +from ..util import GIG from ..util.devices import choose_torch_device from .models import BaseModelType, ModelBase, ModelType, SubModelType @@ -63,20 +65,10 @@ class CacheStats(object): loaded_model_sizes: Dict[str, int] = field(default_factory=dict) -class ModelLocker(object): - "Forward declaration" - pass - - -class ModelCache(object): - "Forward declaration" - pass - - class _CacheRecord: size: int model: Any - cache: ModelCache + cache: "ModelCache" _locks: int def __init__(self, cache, model: Any, size: int): @@ -112,10 +104,9 @@ def __init__( execution_device: torch.device = torch.device("cuda"), storage_device: torch.device = torch.device("cpu"), precision: torch.dtype = torch.float16, - sequential_offload: bool = False, lazy_offloading: bool = True, sha_chunksize: int = 16777216, - logger: types.ModuleType = logger, + logger: Logger = InvokeAILogger.get_logger(), ): """ :param max_cache_size: Maximum size of the RAM cache [6.0 GB] @@ -123,7 +114,6 @@ def __init__( :param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param precision: Precision for loaded models [torch.float16] :param lazy_offloading: Keep model in VRAM until another model needs to be loaded - :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially :param sha_chunksize: Chunksize to use when calculating sha256 model hash """ self.model_infos: Dict[str, ModelBase] = dict() @@ -138,40 +128,37 @@ def __init__( self.logger = logger # used for stats collection - self.stats = None + self.stats: Optional[CacheStats] = None - self._cached_models = dict() - self._cache_stack = list() + self._cached_models: Dict[str, _CacheRecord] = dict() + self._cache_stack: List[str] = list() + # Note that the combination of model_path and submodel_type + # are sufficient to generate a unique cache key. 
This key + # is not the same as the unique hash used to identify models + # in invokeai.backend.model_manager.storage def get_key( self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, + model_path: Path, submodel_type: Optional[SubModelType] = None, ): - key = f"{model_path}:{base_model}:{model_type}" + key = model_path.as_posix() if submodel_type: key += f":{submodel_type}" return key def _get_model_info( self, - model_path: str, + model_path: Path, model_class: Type[ModelBase], base_model: BaseModelType, model_type: ModelType, ): - model_info_key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=None, - ) + model_info_key = self.get_key(model_path=model_path) if model_info_key not in self.model_infos: self.model_infos[model_info_key] = model_class( - model_path, + model_path.as_posix(), base_model, model_type, ) @@ -200,12 +187,8 @@ def get_model( base_model=base_model, model_type=model_type, ) - key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=submodel, - ) + key = self.get_key(model_path, submodel) + # TODO: lock for no copies on simultaneous calls? cache_entry = self._cached_models.get(key, None) if cache_entry is None: @@ -253,7 +236,7 @@ def get_model( self.stats.hits += 1 if self.stats: - self.stats.cache_size = self.max_cache_size * GIG + self.stats.cache_size = int(self.max_cache_size * GIG) self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size()) self.stats.in_cache = len(self._cached_models) self.stats.loaded_model_sizes[key] = max( @@ -306,8 +289,12 @@ def _move_model_to_device(self, key: str, target_device: torch.device): ) class ModelLocker(object): + """Context manager that locks models into VRAM.""" + def __init__(self, cache, key, model, gpu_load, size_needed): """ + Initialize a context manager object that locks models into VRAM. + :param cache: The model_cache object :param key: The key of the model to lock in GPU :param model: The model to lock @@ -366,18 +353,6 @@ def uncache_model(self, cache_id: str): self._cache_stack.remove(cache_id) self._cached_models.pop(cache_id, None) - def model_hash( - self, - model_path: Union[str, Path], - ) -> str: - """ - Given the HF repo id or path to a model on disk, returns a unique - hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs - - :param model_path: Path to model file/directory on disk. - """ - return self._local_model_hash(model_path) - def cache_size(self) -> float: """Return the current size of the cache, in GB.""" return self._cache_size() / GIG @@ -429,8 +404,8 @@ def _make_cache_room(self, model_size): refs = sys.getrefcount(cache_entry.model) - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately + # Manually clear local variable references of just finished function calls. + # For some reason python doesn't want to garbage collect it even when gc.collect() is called if refs > 2: while True: cleared = False diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py new file mode 100644 index 00000000000..6540f7d51bd --- /dev/null +++ b/invokeai/backend/model_manager/config.py @@ -0,0 +1,366 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Configuration definitions for image generation models. 
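A minimal sketch of the failure path, assuming only the exports listed in `invokeai/backend/model_manager/__init__.py` above: pairs of `model_format` and `model_type` that are not registered in `ModelConfigFactory._class_map` are rejected with `InvalidModelConfigException` (the happy path is shown under "Typical usage" below).

```
from invokeai.backend.model_manager import (
    InvalidModelConfigException,
    ModelConfigFactory,
    ModelFormat,
    ModelType,
)

try:
    # The lycoris format is registered only for LoRA models, so pairing it
    # with a "main" model type is not a recognized combination.
    ModelConfigFactory.make_config(
        dict(
            path="models/sd-1/main/foo.ckpt",
            name="foo",
            base_model="sd-1",
            model_type=ModelType.Main,
            model_format=ModelFormat.Lycoris,
        )
    )
except InvalidModelConfigException as err:
    print(err)
```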
+ +Typical usage: + + from invokeai.backend.model_manager import ModelConfigFactory + raw = dict(path='models/sd-1/main/foo.ckpt', + name='foo', + base_model='sd-1', + model_type='main', + config='configs/stable-diffusion/v1-inference.yaml', + variant='normal', + model_format='checkpoint' + ) + config = ModelConfigFactory.make_config(raw) + print(config.name) + +Validation errors will raise an InvalidModelConfigException error. + +""" +import warnings +from enum import Enum +from typing import List, Literal, Optional, Type, Union + +import pydantic + +# import these so that we can silence them +from diffusers import logging as diffusers_logging +from omegaconf.listconfig import ListConfig # to support the yaml backend +from pydantic import BaseModel, Extra, Field +from pydantic.error_wrappers import ValidationError +from transformers import logging as transformers_logging + + +class InvalidModelConfigException(Exception): + """Exception for when config parser doesn't recognized this combination of model type and format.""" + + +class BaseModelType(str, Enum): + """Base model type.""" + + Any = "any" + StableDiffusion1 = "sd-1" + StableDiffusion2 = "sd-2" + StableDiffusionXL = "sdxl" + StableDiffusionXLRefiner = "sdxl-refiner" + # Kandinsky2_1 = "kandinsky-2.1" + + +class ModelType(str, Enum): + """Model type.""" + + ONNX = "onnx" + Main = "main" + Vae = "vae" + Lora = "lora" + ControlNet = "controlnet" # used by model_probe + TextualInversion = "embedding" + IPAdapter = "ip_adapter" + CLIPVision = "clip_vision" + T2IAdapter = "t2i_adapter" + + +class SubModelType(str, Enum): + """Submodel type.""" + + UNet = "unet" + TextEncoder = "text_encoder" + TextEncoder2 = "text_encoder_2" + Tokenizer = "tokenizer" + Tokenizer2 = "tokenizer_2" + Vae = "vae" + VaeDecoder = "vae_decoder" + VaeEncoder = "vae_encoder" + Scheduler = "scheduler" + SafetyChecker = "safety_checker" + + +class ModelVariantType(str, Enum): + """Variant type.""" + + Normal = "normal" + Inpaint = "inpaint" + Depth = "depth" + + +class ModelFormat(str, Enum): + """Storage format of model.""" + + Diffusers = "diffusers" + Checkpoint = "checkpoint" + Lycoris = "lycoris" + Onnx = "onnx" + Olive = "olive" + EmbeddingFile = "embedding_file" + EmbeddingFolder = "embedding_folder" + InvokeAI = "invokeai" + + +class SchedulerPredictionType(str, Enum): + """Scheduler prediction type.""" + + Epsilon = "epsilon" + VPrediction = "v_prediction" + Sample = "sample" + + +# TODO: use this +class ModelError(str, Enum): + NotFound = "not_found" + + +class ModelConfigBase(BaseModel): + """Base class for model configuration information.""" + + path: str + name: str + base_model: BaseModelType + model_type: ModelType + model_format: ModelFormat + key: str = Field( + description="key for model derived from original hash", default="" + ) # assigned on the first install + hash: Optional[str] = Field( + description="current hash key for model", default=None + ) # if model is converted or otherwise modified, this will hold updated hash + description: Optional[str] = Field(None) + author: Optional[str] = Field(description="Model author") + license: Optional[str] = Field(description="License string") + source: Optional[str] = Field(description="Model download source (URL or repo_id)") + thumbnail_url: Optional[str] = Field(description="URL of thumbnail image") + tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable + + class Config: + """Pydantic configuration hint.""" + + use_enum_values = False 
+ extra = Extra.forbid + validate_assignment = True + + @pydantic.validator("tags", pre=True) + @classmethod + def _fix_tags(cls, v): + if isinstance(v, ListConfig): # to support yaml backend + v = list(v) + return v + + def update(self, attributes: dict): + """Update the object with fields in dict.""" + for key, value in attributes.items(): + setattr(self, key, value) # may raise a validation error + + +class CheckpointConfig(ModelConfigBase): + """Model config for checkpoint-style models.""" + + model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + config: str = Field(description="path to the checkpoint model config file") + + +class DiffusersConfig(ModelConfigBase): + """Model config for diffusers-style models.""" + + model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + + +class LoRAConfig(ModelConfigBase): + """Model config for LoRA/Lycoris models.""" + + model_format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers] + + +class VaeCheckpointConfig(ModelConfigBase): + """Model config for standalone VAE models.""" + + model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + + +class VaeDiffusersConfig(ModelConfigBase): + """Model config for standalone VAE models (diffusers version).""" + + model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + + +class ControlNetDiffusersConfig(DiffusersConfig): + """Model config for ControlNet models (diffusers version).""" + + model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + + +class ControlNetCheckpointConfig(CheckpointConfig): + """Model config for ControlNet models (diffusers version).""" + + model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + + +class TextualInversionConfig(ModelConfigBase): + """Model config for textual inversion embeddings.""" + + model_format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] + + +class MainConfig(ModelConfigBase): + """Model config for main models.""" + + vae: Optional[str] = Field(None) + variant: ModelVariantType = ModelVariantType.Normal + + +class MainCheckpointConfig(CheckpointConfig, MainConfig): + """Model config for main checkpoint models.""" + + +class MainDiffusersConfig(DiffusersConfig, MainConfig): + """Model config for main diffusers models.""" + + +class ONNXSD1Config(MainConfig): + """Model config for ONNX format models based on sd-1.""" + + model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + + +class ONNXSD2Config(MainConfig): + """Model config for ONNX format models based on sd-2.""" + + model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + # No yaml config file for ONNX, so these are part of config + prediction_type: SchedulerPredictionType + upcast_attention: bool + + +class IPAdapterConfig(ModelConfigBase): + """Model config for IP Adaptor format models.""" + + model_format: Literal[ModelFormat.InvokeAI] + + +class CLIPVisionDiffusersConfig(ModelConfigBase): + """Model config for ClipVision.""" + + model_format: Literal[ModelFormat.Diffusers] + + +class T2IConfig(ModelConfigBase): + """Model config for T2I.""" + + model_format: Literal[ModelFormat.Diffusers] + + +AnyModelConfig = Union[ + ModelConfigBase, + MainCheckpointConfig, + MainDiffusersConfig, + LoRAConfig, + TextualInversionConfig, + ONNXSD1Config, + ONNXSD2Config, + VaeCheckpointConfig, + VaeDiffusersConfig, + ControlNetDiffusersConfig, + ControlNetCheckpointConfig, + IPAdapterConfig, + CLIPVisionDiffusersConfig, + T2IConfig, +] + + +class ModelConfigFactory(object): + """Class for parsing 
config dicts into StableDiffusion Config obects.""" + + _class_map: dict = { + ModelFormat.Checkpoint: { + ModelType.Main: MainCheckpointConfig, + ModelType.Vae: VaeCheckpointConfig, + }, + ModelFormat.Diffusers: { + ModelType.Main: MainDiffusersConfig, + ModelType.Lora: LoRAConfig, + ModelType.Vae: VaeDiffusersConfig, + ModelType.ControlNet: ControlNetDiffusersConfig, + ModelType.CLIPVision: CLIPVisionDiffusersConfig, + }, + ModelFormat.Lycoris: { + ModelType.Lora: LoRAConfig, + }, + ModelFormat.Onnx: { + ModelType.ONNX: { + BaseModelType.StableDiffusion1: ONNXSD1Config, + BaseModelType.StableDiffusion2: ONNXSD2Config, + }, + }, + ModelFormat.Olive: { + ModelType.ONNX: { + BaseModelType.StableDiffusion1: ONNXSD1Config, + BaseModelType.StableDiffusion2: ONNXSD2Config, + }, + }, + ModelFormat.EmbeddingFile: { + ModelType.TextualInversion: TextualInversionConfig, + }, + ModelFormat.EmbeddingFolder: { + ModelType.TextualInversion: TextualInversionConfig, + }, + ModelFormat.InvokeAI: { + ModelType.IPAdapter: IPAdapterConfig, + }, + } + + @classmethod + def make_config( + cls, + model_data: Union[dict, ModelConfigBase], + key: Optional[str] = None, + dest_class: Optional[Type] = None, + ) -> AnyModelConfig: + """ + Return the appropriate config object from raw dict values. + + :param model_data: A raw dict corresponding the obect fields to be + parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase + object, which will be passed through unchanged. + :param dest_class: The config class to be returned. If not provided, will + be selected automatically. + """ + if isinstance(model_data, ModelConfigBase): + if key: + model_data.key = key + return model_data + try: + model_format = model_data.get("model_format") + model_type = model_data.get("model_type") + model_base = model_data.get("base_model") + class_to_return = dest_class or cls._class_map[model_format][model_type] + if isinstance(class_to_return, dict): # additional level allowed + class_to_return = class_to_return[model_base] + model = class_to_return.parse_obj(model_data) + if key: + model.key = key # ensure consistency + return model + except KeyError as exc: + raise InvalidModelConfigException( + f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'" + ) from exc + except ValidationError as exc: + raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc + + +# TO DO: Move this somewhere else +class SilenceWarnings(object): + """Context manager to temporarily lower verbosity of diffusers & transformers warning messages.""" + + def __init__(self): + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self): + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, type, value, traceback): + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py similarity index 98% rename from invokeai/backend/model_management/convert_ckpt_to_diffusers.py rename to invokeai/backend/model_manager/convert_ckpt_to_diffusers.py index 0a3a63dad63..ed89d783061 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ 
b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -19,9 +19,8 @@ import re from contextlib import nullcontext -from io import BytesIO from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Union import requests import torch @@ -1223,7 +1222,7 @@ def download_from_original_stable_diffusion_ckpt( # scan model scan_result = scan_file_path(checkpoint_path) if scan_result.infected_files != 0: - raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." + raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) @@ -1272,15 +1271,15 @@ def download_from_original_stable_diffusion_ckpt( # only refiner xl has embedder and one text embedders config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" - original_config_file = BytesIO(requests.get(config_url).content) + original_config_file = requests.get(config_url).text original_config = OmegaConf.load(original_config_file) - if original_config["model"]["params"].get("use_ema") is not None: - extract_ema = original_config["model"]["params"]["use_ema"] + if original_config.model["params"].get("use_ema") is not None: + extract_ema = original_config.model["params"]["use_ema"] if ( model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] - and original_config["model"]["params"].get("parameterization") == "v" + and original_config.model["params"].get("parameterization") == "v" ): prediction_type = "v_prediction" upcast_attention = True @@ -1312,11 +1311,11 @@ def download_from_original_stable_diffusion_ckpt( num_in_channels = 4 if "unet_config" in original_config.model.params: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" + "parameterization" in original_config.model["params"] + and original_config.model["params"]["parameterization"] == "v" ): if prediction_type is None: # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` @@ -1437,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt( if model_type == "FrozenOpenCLIPEmbedder": config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} + config_kwargs: Dict[str, Union[str, int]] = {"subfolder": "text_encoder"} text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") @@ -1664,7 +1663,7 @@ def download_controlnet_from_original_ckpt( # scan model scan_result = scan_file_path(checkpoint_path) if scan_result.infected_files != 0: - raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." + raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) @@ -1685,7 +1684,7 @@ def download_controlnet_from_original_ckpt( original_config = OmegaConf.load(original_config_file) if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if "control_stage_config" not in original_config.model.params: raise ValueError("`control_stage_config` not present in original config") @@ -1725,7 +1724,7 @@ def convert_ckpt_to_diffusers( and in addition a path-like object indicating the location of the desired diffusers model to be written. """ - pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + pipe = download_from_original_stable_diffusion_ckpt(str(checkpoint_path), **kwargs) pipe.save_pretrained( dump_path, @@ -1743,6 +1742,6 @@ def convert_controlnet_to_diffusers( and in addition a path-like object indicating the location of the desired diffusers model to be written. """ - pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + pipe = download_controlnet_from_original_ckpt(str(checkpoint_path), **kwargs) pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_manager/download/__init__.py b/invokeai/backend/model_manager/download/__init__.py new file mode 100644 index 00000000000..9d25952035d --- /dev/null +++ b/invokeai/backend/model_manager/download/__init__.py @@ -0,0 +1,11 @@ +"""Initialization file for threaded download manager.""" + +from .base import ( # noqa F401 + DownloadEventHandler, + DownloadJobBase, + DownloadJobStatus, + DownloadQueueBase, + UnknownJobIDException, +) +from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401 +from .queue import DownloadJobPath, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue # noqa F401 diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py new file mode 100644 index 00000000000..8615d6829ac --- /dev/null +++ b/invokeai/backend/model_manager/download/base.py @@ -0,0 +1,260 @@ +# Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Development Team +"""Abstract base class for a multithreaded model download queue.""" + +import threading +from abc import ABC, abstractmethod +from enum import Enum +from functools import total_ordering +from pathlib import Path +from typing import Any, Callable, List, Optional, Union + +import requests +from pydantic import BaseModel, Field +from pydantic.networks import AnyHttpUrl + +from invokeai.app.services.config import InvokeAIAppConfig + + +class DownloadJobStatus(str, Enum): + """State of a download job.""" + + IDLE = "idle" # not enqueued, will not run + ENQUEUED = "enqueued" # enqueued but not yet active + RUNNING = "running" # actively downloading + PAUSED = "paused" # previously started, now paused + COMPLETED = "completed" # finished running + ERROR = "error" # terminated with an error message + CANCELLED = "cancelled" # terminated by caller + + +class UnknownJobIDException(Exception): + """Raised when an invalid Job is referenced.""" + + +DownloadEventHandler = Callable[["DownloadJobBase"], None] + + +@total_ordering +class DownloadJobBase(BaseModel): + """Class to monitor and control a model download request.""" + + priority: int = Field(default=10, description="Queue priority; lower values are higher priority") + id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel + source: Any = Field(description="Where to download from. Specific types specified in child classes.") + destination: Path = Field(description="Destination of downloaded model on local disk") + status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download") + event_handlers: Optional[List[DownloadEventHandler]] = Field( + description="Callables that will be called whenever job status changes", + default_factory=list, + ) + job_started: Optional[float] = Field(description="Timestamp for when the download job started") + job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)") + job_sequence: Optional[int] = Field( + description="Counter that records order in which this job was dequeued (used in unit testing)" + ) + preserve_partial_downloads: bool = Field( + description="if true, then preserve partial downloads when cancelled or errored", default=False + ) + error: Optional[Exception] = Field(default=None, description="Exception that caused an error") + + def add_event_handler(self, handler: DownloadEventHandler): + """Add an event handler to the end of the handlers list.""" + if self.event_handlers is not None: + self.event_handlers.append(handler) + + def clear_event_handlers(self): + """Clear all event handlers.""" + self.event_handlers = list() + + def cleanup(self, preserve_partial_downloads: bool = False): + """Possibly do some action when work is finished.""" + pass + + class Config: + """Config object for this pydantic class.""" + + arbitrary_types_allowed = True + validate_assignment = True + + def __lt__(self, other: "DownloadJobBase") -> bool: + """ + Return True if self.priority < other.priority. + + :param other: The DownloadJobBase that this will be compared against. 
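A minimal sketch of the ordering contract, using placeholder URLs and file names: `@total_ordering` together with this `__lt__` lets jobs sit in a priority structure, with lower `priority` values sorting first.

```
from pathlib import Path

from invokeai.backend.model_manager.download import DownloadJobBase

a = DownloadJobBase(source="https://example.com/a.bin", destination=Path("a.bin"), priority=1)
b = DownloadJobBase(source="https://example.com/b.bin", destination=Path("b.bin"))  # default priority 10

assert a < b   # 1 sorts ahead of 10
assert b >= a  # remaining comparisons are derived by functools.total_ordering
```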
+ """ + if not hasattr(other, "priority"): + return NotImplemented + return self.priority < other.priority + + +class DownloadQueueBase(ABC): + """Abstract base class for managing model downloads.""" + + @abstractmethod + def __init__( + self, + max_parallel_dl: int = 5, + event_handlers: List[DownloadEventHandler] = [], + requests_session: Optional[requests.sessions.Session] = None, + quiet: bool = False, + ): + """ + Initialize DownloadQueue. + + :param max_parallel_dl: Number of simultaneous downloads allowed [5]. + :param event_handler: Optional callable that will be called each time a job status changes. + :param requests_session: Optional requests.sessions.Session object, for unit tests. + :param quiet: If true, don't log the start of download jobs. Useful for subrequests. + """ + pass + + @abstractmethod + def create_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + destdir: Path, + priority: int = 10, + start: Optional[bool] = True, + filename: Optional[Path] = None, + variant: Optional[str] = None, # FIXME: variant is only used in one specific subclass + access_token: Optional[str] = None, + event_handlers: List[DownloadEventHandler] = [], + ) -> DownloadJobBase: + """ + Create and submit a download job. + + :param source: Source of the download - URL, repo_id or Path + :param destdir: Directory to download into. + :param priority: Initial priority for this job [10] + :param filename: Optional name of file, if not provided + will use the content-disposition field to assign the name. + :param start: Immediately start job [True] + :param variant: Variant to download, such as "fp16" (repo_ids only). + :param event_handlers: Optional callables that will be called whenever job status changes. + :returns the job: job.id will be a non-negative value after execution + + Known variants currently are: + 1. onnx + 2. openvino + 3. fp16 + 4. None (usually returns fp32 model) + """ + pass + + def submit_download_job( + self, + job: DownloadJobBase, + start: Optional[bool] = True, + ): + """ + Submit a download job. + + :param job: A DownloadJobBase + :param start: Immediately start job [True] + + After execution, `job.id` will be set to a non-negative value. + """ + pass + + @abstractmethod + def release(self): + """ + Release resources used by queue. + + If threaded downloads are + used, then this will stop the threads. + """ + pass + + @abstractmethod + def list_jobs(self) -> List[DownloadJobBase]: + """ + List active DownloadJobBases. + + :returns List[DownloadJobBase]: List of download jobs whose state is not "completed." + """ + pass + + @abstractmethod + def id_to_job(self, id: int) -> DownloadJobBase: + """ + Return the DownloadJobBase corresponding to the string ID. + + :param id: ID of the DownloadJobBase. + + Exceptions: + * UnknownJobException + + Note that once a job is completed, id_to_job() may no longer + recognize the job. Call id_to_job() before the job completes + if you wish to keep the job object around after it has + completed work. + """ + pass + + @abstractmethod + def start_all_jobs(self): + """Enqueue all stopped jobs.""" + pass + + @abstractmethod + def pause_all_jobs(self): + """Pause and dequeue all active jobs.""" + pass + + @abstractmethod + def prune_jobs(self): + """Prune completed and errored queue items from the job list.""" + pass + + @abstractmethod + def cancel_all_jobs(self, preserve_partial: bool = False): + """ + Cancel all jobs (those in enqueued, running and paused states). 
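Taken together, the abstract interface above supports a workflow along these lines; this is a sketch that assumes the concrete `DownloadQueue` accepts the same constructor arguments as the abstract `__init__`, and the URL and destination directory are placeholders.

```
from pathlib import Path

from invokeai.backend.model_manager.download import DownloadJobBase, DownloadQueue


def log_status(job: DownloadJobBase) -> None:
    # Event handlers are called each time a job's status changes.
    print(f"{job.source}: {job.status}")


queue = DownloadQueue(event_handlers=[log_status])
queue.create_download_job(
    source="https://example.com/models/foo.safetensors",  # placeholder URL
    destdir=Path("./downloads"),
    priority=10,
)
queue.join()     # wait until every job is off the queue
queue.release()  # stop the download threads
```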
+ + :param preserve_partial: Keep partially downloaded files [False]. + """ + pass + + @abstractmethod + def start_job(self, job: DownloadJobBase): + """Start the job putting it into ENQUEUED state.""" + pass + + @abstractmethod + def pause_job(self, job: DownloadJobBase): + """Pause the job, putting it into PAUSED state.""" + pass + + @abstractmethod + def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False): + """ + Cancel the job, clearing partial downloads and putting it into CANCELLED state. + + :param preserve_partial: Keep partial downloads [False] + """ + pass + + @abstractmethod + def join(self): + """ + Wait until all jobs are off the queue. + + Note that once a job is completed, id_to_job() will + no longer recognize the job. + """ + pass + + @abstractmethod + def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]: + """Based on the job type select the download method.""" + pass + + @abstractmethod + def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl: + """ + Given a job, translate its source field into a downloadable URL. + + Intended to be subclassed to cover various source types. + """ + pass diff --git a/invokeai/backend/model_manager/download/model_queue.py b/invokeai/backend/model_manager/download/model_queue.py new file mode 100644 index 00000000000..1c2e9766169 --- /dev/null +++ b/invokeai/backend/model_manager/download/model_queue.py @@ -0,0 +1,370 @@ +import re +from pathlib import Path +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +from huggingface_hub import HfApi, hf_hub_url +from pydantic import BaseModel, Field, parse_obj_as, validator +from pydantic.networks import AnyHttpUrl + +from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase +from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue + +# regular expressions used to dispatch appropriate downloaders and metadata retrievers +# endpoint for civitai get-model API +CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)" +CIVITAI_MODEL_PAGE = "https://civitai.com/models/" +CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)" +CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/" +CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/" + +# Regular expressions to describe repo_ids and http urls +REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$" + + +class ModelSourceMetadata(BaseModel): + """Information collected on a downloadable model from its source site.""" + + name: Optional[str] = Field(description="Human-readable name of this model") + author: Optional[str] = Field(description="Author/creator of the model") + description: Optional[str] = Field(description="Description of the model") + license: Optional[str] = Field(description="Model license terms") + thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model") + tags: Optional[List[str]] = Field(description="List of descriptive tags") + + +class DownloadJobWithMetadata(DownloadJobRemoteSource): + """A remote download that has metadata associated with it.""" + + metadata: ModelSourceMetadata = Field( + description="Metadata describing the model, derived from source", default_factory=ModelSourceMetadata + ) + + +class DownloadJobMetadataURL(DownloadJobWithMetadata, DownloadJobURL): + """DownloadJobWithMetadata with validation of the source URL.""" + + +class 
DownloadJobRepoID(DownloadJobWithMetadata): + """Download repo ids.""" + + source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)") + subfolder: Optional[str] = Field( + description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None + ) + variant: Optional[str] = Field(description="Variant, such as 'fp16', to download") + subqueue: Optional[DownloadQueueBase] = Field( + description="a subqueue used for downloading the individual files in the repo_id", default=None + ) + + @validator("source") + @classmethod + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v): + raise ValueError(f"{v}: invalid repo_id format") + return v + + def cleanup(self, preserve_partial_downloads: bool = False): + """Perform action when job is completed.""" + if self.subqueue: + self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads) + self.subqueue.release() + + +class ModelDownloadQueue(DownloadQueue): + """Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai.""" + + def create_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + destdir: Path, + start: bool = True, + priority: int = 10, + filename: Optional[Path] = None, + variant: Optional[str] = None, + access_token: Optional[str] = None, + event_handlers: List[DownloadEventHandler] = [], + ) -> DownloadJobBase: + """Create a download job and return its ID.""" + cls: Optional[Type[DownloadJobBase]] = None + kwargs: Dict[str, Optional[str]] = dict() + + if re.match(HTTP_RE, str(source)): + cls = DownloadJobWithMetadata + kwargs.update(access_token=access_token) + elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)): + cls = DownloadJobRepoID + kwargs.update( + variant=variant, + access_token=access_token, + ) + if cls: + job = cls( + source=source, + destination=Path(destdir) / (filename or "."), + event_handlers=event_handlers, + priority=priority, + **kwargs, + ) + return self.submit_download_job(job, start) + else: + return super().create_download_job( + source=source, + destdir=destdir, + start=start, + priority=priority, + filename=filename, + variant=variant, + access_token=access_token, + event_handlers=event_handlers, + ) + + def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]: + """Based on the job type select the download method.""" + if isinstance(job, DownloadJobRepoID): + return self._download_repoid + elif isinstance(job, DownloadJobWithMetadata): + return self._download_with_resume + else: + return super().select_downloader(job) + + def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl: + """ + Fetch metadata from certain well-known URLs. + + The metadata will be stashed in job.metadata, if found + Return the download URL. 
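A minimal sketch of the source forms this queue dispatches on, assuming it can be constructed with default arguments; the repo_id with a `:v2` subfolder is the example cited by the legacy installer earlier in this diff, and the destination directory is a placeholder.

```
from pathlib import Path

from invokeai.backend.model_manager.download import ModelDownloadQueue

# HTTP(S) sources (including the Civitai URL forms above) become
# DownloadJobWithMetadata jobs; "owner/name" or "owner/name:subfolder"
# strings become DownloadJobRepoID jobs.
queue = ModelDownloadQueue()
job = queue.create_download_job(
    source="monster-labs/control_v1p_sd15_qrcode_monster:v2",  # repo_id:subfolder
    destdir=Path("./downloads"),
    variant="fp16",
)
# Per the base-class contract, job.id is non-negative once the job is submitted.
```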
+ """ + assert isinstance(job, DownloadJobWithMetadata) + metadata = job.metadata + url = job.source + metadata_url = url + model = None + + # a Civitai download URL + if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)): + version = match.group(1) + resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json() + metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"] + metadata.description = metadata.description or ( + f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"] + ) + metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}" + + # a Civitai model page with the version + if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)): + model = match.group(1) + version = int(match.group(2)) + # and without + elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)): + model = match.group(1) + version = None + + if not model: + return parse_obj_as(AnyHttpUrl, url) + + if model: + resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json() + + metadata.author = metadata.author or resp["creator"]["username"] + metadata.tags = metadata.tags or resp["tags"] + metadata.license = ( + metadata.license + or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}" + ) + + if version: + versions = [x for x in resp["modelVersions"] if int(x["id"]) == version] + version_data = versions[0] + else: + version_data = resp["modelVersions"][0] # first one + + metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url + metadata.description = metadata.description or ( + f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}" + if version_data.get("trainedWords") + else version_data.get("description") + ) + + download_url = version_data["downloadUrl"] + + # return the download url + return download_url + + def _download_repoid(self, job: DownloadJobBase) -> None: + """Download a job that holds a huggingface repoid.""" + + def subdownload_event(subjob: DownloadJobBase): + assert isinstance(subjob, DownloadJobRemoteSource) + assert isinstance(job, DownloadJobRemoteSource) + if job.status != DownloadJobStatus.RUNNING: # do not update if we are cancelled or paused + return + if subjob.status == DownloadJobStatus.RUNNING: + bytes_downloaded[subjob.id] = subjob.bytes + job.bytes = sum(bytes_downloaded.values()) + self._update_job_status(job, DownloadJobStatus.RUNNING) + return + + if subjob.status == DownloadJobStatus.ERROR: + job.error = subjob.error + job.cleanup() + self._update_job_status(job, DownloadJobStatus.ERROR) + return + + if subjob.status == DownloadJobStatus.COMPLETED: + bytes_downloaded[subjob.id] = subjob.bytes + job.bytes = sum(bytes_downloaded.values()) + self._update_job_status(job, DownloadJobStatus.RUNNING) + return + + assert isinstance(job, DownloadJobRepoID) + self._lock.acquire() # prevent status from being updated while we are setting up subqueue + self._update_job_status(job, DownloadJobStatus.RUNNING) + try: + job.subqueue = self.__class__( + event_handlers=[subdownload_event], + requests_session=self._requests, + quiet=True, + ) + repo_id = job.source + variant = job.variant + if not job.metadata: + job.metadata = ModelSourceMetadata() + urls_to_download = self._get_repo_info( + repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder + ) + if job.destination.name != Path(repo_id).name: + job.destination = 
job.destination / Path(repo_id).name + bytes_downloaded: Dict[int, int] = dict() + job.total_bytes = 0 + + for url, subdir, file, size in urls_to_download: + job.total_bytes += size + job.subqueue.create_download_job( + source=url, + destdir=job.destination / subdir, + filename=file, + variant=variant, + access_token=job.access_token, + ) + except KeyboardInterrupt as excp: + raise excp + except Exception as excp: + job.error = excp + self._update_job_status(job, DownloadJobStatus.ERROR) + self._logger.error(job.error) + finally: + self._lock.release() + if job.subqueue is not None: + job.subqueue.join() + if job.status == DownloadJobStatus.RUNNING: + self._update_job_status(job, DownloadJobStatus.COMPLETED) + + def _get_repo_info( + self, + repo_id: str, + metadata: ModelSourceMetadata, + variant: Optional[str] = None, + subfolder: Optional[str] = None, + ) -> List[Tuple[AnyHttpUrl, Path, Path, int]]: + """ + Given a repo_id and an optional variant, return list of URLs to download to get the model. + + The metadata field will be updated with model metadata from HuggingFace. + + Known variants currently are: + 1. onnx + 2. openvino + 3. fp16 + 4. None (usually returns fp32 model) + """ + model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True) + sibs = model_info.siblings + paths = [] + + # unfortunately the HF repo contains both files needed for the model + # as well as anything else the owner thought to include in the directory, + # including checkpoint files, different EMA versions, etc. + # This filters out just the file types needed for the model + for x in sibs: + if x.rfilename.endswith((".json", ".txt")): + paths.append(x.rfilename) + elif x.rfilename.endswith(("learned_embeds.bin", "ip_adapter.bin")): + paths.append(x.rfilename) + elif re.search(r"model(\.[^.]+)?\.(safetensors|bin)$", x.rfilename): + paths.append(x.rfilename) + + sizes = {x.rfilename: x.size for x in sibs} + + prefix = "" + if subfolder: + prefix = f"{subfolder}/" + paths = [x for x in paths if x.startswith(prefix)] + + if f"{prefix}model_index.json" in paths: + url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder) + resp = self._requests.get(url) + resp.raise_for_status() # will raise an HTTPError on non-200 status + submodels = resp.json() + paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels] + paths.insert(0, f"{prefix}model_index.json") + urls = [ + ( + hf_hub_url(repo_id, filename=x.as_posix()), + x.parent.relative_to(prefix) or Path("."), + Path(x.name), + sizes[x.as_posix()], + ) + for x in self._select_variants(paths, variant) + ] + if hasattr(model_info, "cardData"): + metadata.license = metadata.license or model_info.cardData.get("license") + metadata.tags = metadata.tags or model_info.tags + metadata.author = metadata.author or model_info.author + return urls + + def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]: + """Select the proper variant files from a list of HuggingFace repo_id paths.""" + result = set() + basenames: Dict[Path, Path] = dict() + for p in paths: + path = Path(p) + + if path.suffix == ".onnx": + if variant == "onnx": + result.add(path) + + elif path.name.startswith("openvino_model"): + if variant == "openvino": + result.add(path) + + elif path.suffix in [".json", ".txt"]: + result.add(path) + + elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]: + parent = path.parent + suffixes = path.suffixes + if len(suffixes) == 2: + file_variant, suffix 
= suffixes + basename = parent / Path(path.stem).stem + else: + file_variant = None + suffix = suffixes[0] + basename = parent / path.stem + + if previous := basenames.get(basename): + if previous.suffix != ".safetensors" and suffix == ".safetensors": + basenames[basename] = path + if file_variant == f".{variant}": + basenames[basename] = path + elif not variant and not file_variant: + basenames[basename] = path + else: + basenames[basename] = path + + else: + continue + + for v in basenames.values(): + result.add(v) + + return result diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py new file mode 100644 index 00000000000..b36066c90f9 --- /dev/null +++ b/invokeai/backend/model_manager/download/queue.py @@ -0,0 +1,432 @@ +# Copyright (c) 2023, Lincoln D. Stein +"""Implementation of multithreaded download queue for invokeai.""" + +import os +import re +import shutil +import threading +import time +import traceback +from pathlib import Path +from queue import PriorityQueue +from typing import Callable, Dict, List, Optional, Set, Union + +import requests +from pydantic import Field +from pydantic.networks import AnyHttpUrl +from requests import HTTPError + +from invokeai.backend.util import InvokeAILogger, Logger + +from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException + +# Maximum number of bytes to download during each call to requests.iter_content() +DOWNLOAD_CHUNK_SIZE = 100000 + +# marker that the queue is done and that thread should exit +STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/") + +# regular expression for picking up a URL +HTTP_RE = r"^https?://" + + +class DownloadJobPath(DownloadJobBase): + """Download from a local Path.""" + + source: Path = Field(description="Local filesystem Path where model can be found") + + +class DownloadJobRemoteSource(DownloadJobBase): + """A DownloadJob from a remote source that provides progress info.""" + + bytes: int = Field(default=0, description="Bytes downloaded so far") + total_bytes: int = Field(default=0, description="Total bytes to download") + access_token: Optional[str] = Field(description="access token needed to access this resource") + + +class DownloadJobURL(DownloadJobRemoteSource): + """Job declaration for downloading individual URLs.""" + + source: AnyHttpUrl = Field(description="URL to download") + + +class DownloadQueue(DownloadQueueBase): + """Class for queued download of models.""" + + _jobs: Dict[int, DownloadJobBase] + _worker_pool: Set[threading.Thread] + _queue: PriorityQueue + _lock: threading.RLock # to allow for reentrant locking for method calls + _logger: Logger + _event_handlers: List[DownloadEventHandler] = Field(default_factory=list) + _next_job_id: int = 0 + _sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order + _requests: requests.sessions.Session + _quiet: bool = False + + def __init__( + self, + max_parallel_dl: int = 5, + event_handlers: List[DownloadEventHandler] = [], + requests_session: Optional[requests.sessions.Session] = None, + quiet: bool = False, + ): + """ + Initialize DownloadQueue. + + :param max_parallel_dl: Number of simultaneous downloads allowed [5]. + :param event_handler: Optional callable that will be called each time a job status changes. + :param requests_session: Optional requests.sessions.Session object, for unit tests. 
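+        A minimal, illustrative sketch of typical use (the URL and destination
+        directory below are placeholders):
+
+            queue = DownloadQueue(max_parallel_dl=2)
+            job = queue.create_download_job(
+                source="https://example.com/model.safetensors",
+                destdir=Path("/tmp/models"),
+            )
+            queue.join()     # wait for all queued jobs to finish
+            queue.release()  # then signal the worker threads to exit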
+ """ + self._jobs = dict() + self._next_job_id = 0 + self._queue = PriorityQueue() + self._worker_pool = set() + self._lock = threading.RLock() + self._logger = InvokeAILogger.get_logger() + self._event_handlers = event_handlers + self._requests = requests_session or requests.Session() + self._quiet = quiet + + self._start_workers(max_parallel_dl) + + def create_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + destdir: Path, + start: bool = True, + priority: int = 10, + filename: Optional[Path] = None, + variant: Optional[str] = None, + access_token: Optional[str] = None, + event_handlers: List[DownloadEventHandler] = [], + ) -> DownloadJobBase: + """Create a download job and return its ID.""" + kwargs: Dict[str, Optional[str]] = dict() + + cls = DownloadJobBase + if Path(source).exists(): + cls = DownloadJobPath + elif re.match(HTTP_RE, str(source)): + cls = DownloadJobURL + kwargs.update(access_token=access_token) + else: + raise NotImplementedError(f"Don't know what to do with this type of source: {source}") + + job = cls( + source=source, + destination=Path(destdir) / (filename or "."), + event_handlers=event_handlers, + priority=priority, + **kwargs, + ) + + return self.submit_download_job(job, start) + + def submit_download_job( + self, + job: DownloadJobBase, + start: Optional[bool] = True, + ): + """Submit a job.""" + # add the queue's handlers + for handler in self._event_handlers: + job.add_event_handler(handler) + with self._lock: + job.id = self._next_job_id + self._jobs[job.id] = job + self._next_job_id += 1 + if start: + self.start_job(job) + return job + + def release(self): + """Signal our threads to exit when queue done.""" + for thread in self._worker_pool: + if thread.is_alive(): + self._queue.put(STOP_JOB) + + def join(self): + """Wait for all jobs to complete.""" + self._queue.join() + + def list_jobs(self) -> List[DownloadJobBase]: + """List all the jobs.""" + return list(self._jobs.values()) + + def prune_jobs(self): + """Prune completed and errored queue items from the job list.""" + with self._lock: + to_delete = set() + try: + for job_id, job in self._jobs.items(): + if self._in_terminal_state(job): + to_delete.add(job_id) + for job_id in to_delete: + del self._jobs[job_id] + except KeyError as excp: + raise UnknownJobIDException("Unrecognized job") from excp + + def id_to_job(self, id: int) -> DownloadJobBase: + """Translate a job ID into a DownloadJobBase object.""" + try: + return self._jobs[id] + except KeyError as excp: + raise UnknownJobIDException("Unrecognized job") from excp + + def start_job(self, job: DownloadJobBase): + """Enqueue (start) the indicated job.""" + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + self._update_job_status(job, DownloadJobStatus.ENQUEUED) + self._queue.put(job) + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp + + def pause_job(self, job: DownloadJobBase): + """ + Pause (dequeue) the indicated job. + + The job can be restarted with start_job() and the download will pick up + from where it left off. + """ + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + self._update_job_status(job, DownloadJobStatus.PAUSED) + job.cleanup() + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp + + def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False): + """ + Cancel the indicated job. + + If it is running it will be stopped. 
+ job.status will be set to DownloadJobStatus.CANCELLED + """ + with self._lock: + try: + assert isinstance(self._jobs[job.id], DownloadJobBase) + job.preserve_partial_downloads = preserve_partial + self._update_job_status(job, DownloadJobStatus.CANCELLED) + job.cleanup() + except (AssertionError, KeyError) as excp: + raise UnknownJobIDException("Unrecognized job") from excp + + def start_all_jobs(self): + """Start (enqueue) all jobs that are idle or paused.""" + with self._lock: + for job in self._jobs.values(): + if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]: + self.start_job(job) + + def pause_all_jobs(self): + """Pause all running jobs.""" + with self._lock: + for job in self._jobs.values(): + if not self._in_terminal_state(job): + self.pause_job(job) + + def cancel_all_jobs(self, preserve_partial: bool = False): + """Cancel all jobs (those not in enqueued, running or paused state).""" + with self._lock: + for job in self._jobs.values(): + if not self._in_terminal_state(job): + self.cancel_job(job, preserve_partial) + + def _in_terminal_state(self, job: DownloadJobBase): + return job.status in [ + DownloadJobStatus.COMPLETED, + DownloadJobStatus.ERROR, + DownloadJobStatus.CANCELLED, + ] + + def _start_workers(self, max_workers: int): + """Start the requested number of worker threads.""" + for i in range(0, max_workers): + worker = threading.Thread(target=self._download_next_item, daemon=True) + worker.start() + self._worker_pool.add(worker) + + def _download_next_item(self): + """Worker thread gets next job on priority queue.""" + done = False + while not done: + job = self._queue.get() + + with self._lock: + job.job_sequence = self._sequence + self._sequence += 1 + + try: + if job == STOP_JOB: # marker that queue is done + done = True + + if job.status == DownloadJobStatus.ENQUEUED: + if not self._quiet: + self._logger.info(f"{job.source}: Downloading to {job.destination}") + do_download = self.select_downloader(job) + do_download(job) + + if job.status == DownloadJobStatus.CANCELLED: + self._cleanup_cancelled_job(job) + finally: + self._queue.task_done() + + def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]: + """Based on the job type select the download method.""" + if isinstance(job, DownloadJobURL): + return self._download_with_resume + elif isinstance(job, DownloadJobPath): + return self._download_path + else: + raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}") + + def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl: + return job.source + + def _download_with_resume(self, job: DownloadJobBase): + """Do the actual download.""" + dest = None + try: + assert isinstance(job, DownloadJobRemoteSource) + url = self.get_url_for_job(job) + header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} + open_mode = "wb" + exist_size = 0 + + resp = self._requests.get(url, headers=header, stream=True) + content_length = int(resp.headers.get("content-length", 0)) + job.total_bytes = content_length + + if job.destination.is_dir(): + try: + file_name = "" + if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]): + file_name = match.group(1) + assert file_name != "" + self._validate_filename( + job.destination.as_posix(), file_name + ) # will raise a ValueError exception if file_name is suspicious + except ValueError: + self._logger.warning( + f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead" + ) 
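+                    # _validate_filename() rejects names containing "/", names
+                    # starting with "..", and over-long names or paths; fall
+                    # back to the URL's basename in that case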
+ file_name = os.path.basename(url) + except (KeyError, AssertionError): + file_name = os.path.basename(url) + job.destination = job.destination / file_name + dest = job.destination + else: + dest = job.destination + dest.parent.mkdir(parents=True, exist_ok=True) + + if dest.exists(): + job.bytes = dest.stat().st_size + header["Range"] = f"bytes={job.bytes}-" + open_mode = "ab" + resp = self._requests.get(url, headers=header, stream=True) # new request with range + + if exist_size > content_length: + self._logger.warning("corrupt existing file found. re-downloading") + os.remove(dest) + exist_size = 0 + + if resp.status_code == 416 or (content_length > 0 and exist_size == content_length): + self._logger.warning(f"{dest}: complete file found. Skipping.") + self._update_job_status(job, DownloadJobStatus.COMPLETED) + return + + if resp.status_code == 206 or exist_size > 0: + self._logger.warning(f"{dest}: partial file found. Resuming") + elif resp.status_code != 200: + raise HTTPError(resp.reason) + else: + self._logger.debug(f"{job.source}: Downloading {job.destination}") + + report_delta = job.total_bytes / 100 # report every 1% change + last_report_bytes = 0 + + self._update_job_status(job, DownloadJobStatus.RUNNING) + with open(dest, open_mode) as file: + for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE): + if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored + return + job.bytes += file.write(data) + if job.bytes - last_report_bytes >= report_delta: + last_report_bytes = job.bytes + self._update_job_status(job) + if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored + return + self._update_job_status(job, DownloadJobStatus.COMPLETED) + except KeyboardInterrupt as excp: + raise excp + except (HTTPError, OSError) as excp: + self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}") + print(traceback.format_exc()) + job.error = excp + self._update_job_status(job, DownloadJobStatus.ERROR) + + def _validate_filename(self, directory: str, filename: str): + pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 + if "/" in filename: + raise ValueError + if filename.startswith(".."): + raise ValueError + if len(filename) > pc_name_max: + raise ValueError + if len(os.path.join(directory, filename)) > os.pathconf(directory, "PC_PATH_MAX"): + raise ValueError + + def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None): + """Optionally change the job status and send an event indicating a change of state.""" + with self._lock: + if new_status: + job.status = new_status + + if self._in_terminal_state(job) and not self._quiet: + self._logger.info(f"{job.source}: Download job completed with status {job.status.value}") + + if new_status == DownloadJobStatus.RUNNING and not job.job_started: + job.job_started = time.time() + elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]: + job.job_ended = time.time() + + if job.event_handlers: + for handler in job.event_handlers: + try: + handler(job) + except KeyboardInterrupt as excp: + raise excp + except Exception as excp: + job.error = excp + if job.status != DownloadJobStatus.ERROR: # let handlers know, but don't cause infinite recursion + self._update_job_status(job, DownloadJobStatus.ERROR) + + def _download_path(self, job: DownloadJobBase): + """Call when the source is a Path or pathlike object.""" + source = Path(job.source).resolve() + destination = 
Path(job.destination).resolve() + try: + self._update_job_status(job, DownloadJobStatus.RUNNING) + if source != destination: + shutil.move(source, destination) + self._update_job_status(job, DownloadJobStatus.COMPLETED) + except OSError as excp: + job.error = excp + self._update_job_status(job, DownloadJobStatus.ERROR) + + def _cleanup_cancelled_job(self, job: DownloadJobBase): + job.cleanup(job.preserve_partial_downloads) + if not job.preserve_partial_downloads: + self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}") + dest = Path(job.destination) + try: + if dest.is_file(): + dest.unlink() + elif dest.is_dir(): + shutil.rmtree(dest.as_posix(), ignore_errors=True) + except OSError as excp: + self._logger(excp) diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py new file mode 100644 index 00000000000..e445fa03abe --- /dev/null +++ b/invokeai/backend/model_manager/hash.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Fast hashing of diffusers and checkpoint-style models. + +Usage: +from invokeai.backend.model_managre.model_hash import FastModelHash +>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') +'a8e693a126ea5b831c96064dc569956f' +""" + +import hashlib +import os +from pathlib import Path +from typing import Dict, Union + +from imohash import hashfile + +from .models import InvalidModelException + + +class FastModelHash(object): + """FastModelHash obect provides one public class method, hash().""" + + @classmethod + def hash(cls, model_location: Union[str, Path]) -> str: + """ + Return hexdigest string for model located at model_location. + + :param model_location: Path to the model + """ + model_location = Path(model_location) + if model_location.is_file(): + return cls._hash_file(model_location) + elif model_location.is_dir(): + return cls._hash_dir(model_location) + else: + raise InvalidModelException(f"Not a valid file or directory: {model_location}") + + @classmethod + def _hash_file(cls, model_location: Union[str, Path]) -> str: + """ + Fasthash a single file and return its hexdigest. + + :param model_location: Path to the model file + """ + # we return md5 hash of the filehash to make it shorter + # cryptographic security not needed here + return hashlib.md5(hashfile(model_location)).hexdigest() + + @classmethod + def _hash_dir(cls, model_location: Union[str, Path]) -> str: + components: Dict[str, str] = {} + + for root, dirs, files in os.walk(model_location): + for file in files: + # only tally tensor files because diffusers config files change slightly + # depending on how the model was downloaded/converted. 
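+                # each weight file is fast-hashed individually; the per-file
+                # digests are combined below in sorted path order, so the result
+                # does not depend on directory traversal order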
+ if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): + continue + path = (Path(root) / file).as_posix() + fast_hash = cls._hash_file(path) + components.update({path: fast_hash}) + + # hash all the model hashes together, using alphabetic file order + md5 = hashlib.md5() + for path, fast_hash in sorted(components.items()): + md5.update(fast_hash.encode("utf-8")) + return md5.hexdigest() diff --git a/invokeai/backend/model_management/libc_util.py b/invokeai/backend/model_manager/libc_util.py similarity index 100% rename from invokeai/backend/model_management/libc_util.py rename to invokeai/backend/model_manager/libc_util.py diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py new file mode 100644 index 00000000000..85a1b189a1e --- /dev/null +++ b/invokeai/backend/model_manager/loader.py @@ -0,0 +1,250 @@ +# Copyright (c) 2023, Lincoln D. Stein +"""Model loader for InvokeAI.""" + +import hashlib +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from shutil import move, rmtree +from typing import Optional, Tuple, Union + +import torch + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device + +from .cache import CacheStats, ModelCache +from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType +from .models import MODEL_CLASSES, InvalidModelException, ModelBase +from .storage import ModelConfigStore + + +@dataclass +class ModelInfo: + """This is a context manager object that is used to intermediate access to a model.""" + + context: ModelCache.ModelLocker + name: str + base_model: BaseModelType + type: Union[ModelType, SubModelType] + key: str + location: Union[Path, str] + precision: torch.dtype + _cache: Optional[ModelCache] = None + + def __enter__(self): + """Context entry.""" + return self.context.__enter__() + + def __exit__(self, *args, **kwargs): + """Context exit.""" + self.context.__exit__(*args, **kwargs) + + +class ModelLoadBase(ABC): + """Abstract base class for a model loader which works with the ModelConfigStore backend.""" + + @abstractmethod + def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo: + """ + Return a model given its key. + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param key: model key, as known to the config backend + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. 
ModelType.Vae) + """ + pass + + @property + @abstractmethod + def store(self) -> ModelConfigStore: + """Return the ModelConfigStore object that supports this loader.""" + pass + + @property + @abstractmethod + def logger(self) -> Logger: + """Return the current logger.""" + pass + + @property + @abstractmethod + def config(self) -> InvokeAIAppConfig: + """Return the config object used by the loader.""" + pass + + @abstractmethod + def collect_cache_stats(self, cache_stats: CacheStats): + """Replace cache statistics.""" + pass + + @abstractmethod + def resolve_model_path(self, path: Union[Path, str]) -> Path: + """Turn a potentially relative path into an absolute one in the models_dir.""" + pass + + @property + @abstractmethod + def precision(self) -> torch.dtype: + """Return torch.fp16 or torch.fp32.""" + pass + + +class ModelLoad(ModelLoadBase): + """Implementation of ModelLoadBase.""" + + _app_config: InvokeAIAppConfig + _store: ModelConfigStore + _cache: ModelCache + _logger: Logger + _cache_keys: dict + + def __init__( + self, + config: InvokeAIAppConfig, + store: Optional[ModelConfigStore] = None, + ): + """ + Initialize ModelLoad object. + + :param config: The app's InvokeAIAppConfig object. + """ + self._app_config = config + self._store = store or ModelRecordServiceBase.open(config) + self._logger = InvokeAILogger.get_logger() + self._cache_keys = dict() + device = torch.device(choose_torch_device()) + device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" + precision = choose_precision(device) if config.precision == "auto" else config.precision + dtype = torch.float32 if precision == "float32" else torch.float16 + + self._logger.info(f"Rendering device = {device} ({device_name})") + self._logger.info(f"Maximum RAM cache size: {config.ram}") + self._logger.info(f"Maximum VRAM cache size: {config.vram}") + self._logger.info(f"Precision: {precision}") + + self._cache = ModelCache( + max_cache_size=config.ram, + max_vram_cache_size=config.vram, + lazy_offloading=config.lazy_offload, + execution_device=device, + precision=dtype, + logger=self._logger, + ) + + @property + def store(self) -> ModelConfigStore: + """Return the ModelConfigStore instance used by this class.""" + return self._store + + @property + def precision(self) -> torch.dtype: + """Return torch.fp16 or torch.fp32.""" + return self._cache.precision + + @property + def logger(self) -> Logger: + """Return the current logger.""" + return self._logger + + @property + def config(self) -> InvokeAIAppConfig: + """Return the config object.""" + return self._app_config + + def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo: + """ + Get the ModelInfo corresponding to the model with key "key". + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param key: model key, as known to the config backend + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. 
ModelType.Vae) + """ + model_config = self.store.get_model(key) # May raise a UnknownModelException + if model_config.model_type == "main" and not submodel_type: + raise InvalidModelException("submodel_type is required when loading a main model") + + submodel_type = SubModelType(submodel_type) if submodel_type else None + + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + + if is_submodel_override: + submodel_type = None + + model_class = self._get_implementation(model_config.base_model, model_config.model_type) + if not model_path.exists(): + raise InvalidModelException(f"Files for model '{key}' not found at {model_path}") + + dst_convert_path = self._get_model_convert_cache_path(model_path) + model_path = self.resolve_model_path( + model_class.convert_if_required( + model_config=model_config, + output_path=dst_convert_path, + ) + ) + + model_context = self._cache.get_model( + model_path=model_path, + model_class=model_class, + base_model=model_config.base_model, + model_type=model_config.model_type, + submodel=submodel_type, + ) + + if key not in self._cache_keys: + self._cache_keys[key] = set() + self._cache_keys[key].add(model_context.key) + + return ModelInfo( + context=model_context, + name=model_config.name, + base_model=model_config.base_model, + type=submodel_type or model_config.model_type, + key=model_config.key, + location=model_path, + precision=self._cache.precision, + _cache=self._cache, + ) + + def collect_cache_stats(self, cache_stats: CacheStats): + """Save CacheStats object for stats collecting.""" + self._cache.stats = cache_stats + + def resolve_model_path(self, path: Union[Path, str]) -> Path: + """Turn a potentially relative path into an absolute one in the models_dir.""" + return self._app_config.models_path / path + + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: + """Get the concrete implementation class for a specific model type.""" + model_class = MODEL_CLASSES[base_model][model_type] + return model_class + + def _get_model_convert_cache_path(self, model_path): + return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest()) + + def _get_model_path( + self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, bool]: + """Extract a model's filesystem path from its config. + + :return: The fully qualified Path of the module (or submodule). + """ + model_path = Path(model_config.path) + is_submodel_override = False + + # Does the config explicitly override the submodel? 
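+        # e.g. a main model's config may carry a standalone "vae" path; when the
+        # caller asks for that submodel we return the override path and flag it,
+        # and get_model() then clears submodel_type so the override is loaded as
+        # a standalone model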
+ if submodel_type is not None and hasattr(model_config, submodel_type): + submodel_path = getattr(model_config, submodel_type) + if submodel_path is not None and len(submodel_path) > 0: + model_path = getattr(model_config, submodel_type) + is_submodel_override = True + + model_path = self.resolve_model_path(model_path) + return model_path, is_submodel_override diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_manager/lora.py similarity index 97% rename from invokeai/backend/model_management/lora.py rename to invokeai/backend/model_manager/lora.py index bb44455c886..602d7f46380 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_manager/lora.py @@ -12,7 +12,7 @@ from safetensors.torch import load_file from transformers import CLIPTextModel, CLIPTokenizer -from .models.lora import LoRAModel +from .models.lora import LoRALayerBase, LoRAModel, LoRAModelRaw """ loras = [ @@ -87,7 +87,7 @@ def apply_lora_unet( def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -97,7 +97,7 @@ def apply_lora_text_encoder( def apply_sdxl_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te1_"): yield @@ -107,7 +107,7 @@ def apply_sdxl_lora_text_encoder( def apply_sdxl_lora_text_encoder2( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te2_"): yield @@ -117,7 +117,7 @@ def apply_sdxl_lora_text_encoder2( def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], prefix: str, ): original_weights = dict() @@ -337,7 +337,7 @@ def apply_lora_text_encoder( def apply_lora( cls, model: IAIOnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, torch.Tensor]], prefix: str, ): from .models.base import IAIOnnxRuntimeModel @@ -348,7 +348,7 @@ def apply_lora( orig_weights = dict() try: - blended_loras = dict() + blended_loras: Dict[str, torch.Tensor] = dict() for lora, lora_weight in loras: for layer_key, layer in lora.layers.items(): diff --git a/invokeai/backend/model_management/memory_snapshot.py b/invokeai/backend/model_manager/memory_snapshot.py similarity index 97% rename from invokeai/backend/model_management/memory_snapshot.py rename to invokeai/backend/model_manager/memory_snapshot.py index 4f43affcf72..8c8d0008315 100644 --- a/invokeai/backend/model_management/memory_snapshot.py +++ b/invokeai/backend/model_manager/memory_snapshot.py @@ -4,7 +4,7 @@ import psutil import torch -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from .libc_util import LibcUtil, Struct_mallinfo2 GB = 2**30 # 1 GB diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_manager/merge.py similarity index 64% rename from invokeai/backend/model_management/model_merge.py rename to invokeai/backend/model_manager/merge.py index 59201d64d98..dbeb4523b93 100644 --- a/invokeai/backend/model_management/model_merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -1,5 +1,5 @@ """ -invokeai.backend.model_management.model_merge exports: +invokeai.backend.model_manager.merge exports: merge_diffusion_models() -- 
combine multiple models by location and return a pipeline object merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml @@ -9,14 +9,17 @@ import warnings from enum import Enum from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional, Set from diffusers import DiffusionPipeline from diffusers import logging as dlogging import invokeai.backend.util.logging as logger +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_install_service import ModelInstallService -from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType +from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelType +from .config import MainConfig class MergeInterpolationMethod(str, Enum): @@ -27,8 +30,18 @@ class MergeInterpolationMethod(str, Enum): class ModelMerger(object): - def __init__(self, manager: ModelManager): - self.manager = manager + _store: ModelConfigStore + _config: InvokeAIAppConfig + + def __init__(self, store: ModelConfigStore, config: Optional[InvokeAIAppConfig] = None): + """ + Initialize a ModelMerger object. + + :param store: Underlying storage manager for the running process. + :param config: InvokeAIAppConfig object (if not provided, default will be selected). + """ + self._store = store + self._config = config or InvokeAIAppConfig.get_config() def merge_diffusion_models( self, @@ -70,15 +83,14 @@ def merge_diffusion_models( def merge_diffusion_models_and_save( self, - model_names: List[str], - base_model: Union[BaseModelType, str], + model_keys: List[str], merged_model_name: str, - alpha: float = 0.5, + alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, + force: Optional[bool] = False, merge_dest_directory: Optional[Path] = None, **kwargs, - ) -> AddModelResult: + ) -> ModelConfigBase: """ :param models: up to three models, designated by their InvokeAI models.yaml model name :param base_model: base model (must be the same for all merged models!) @@ -92,25 +104,38 @@ def merge_diffusion_models_and_save( **kwargs - the default DiffusionPipeline.get_config_dict kwargs: cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ - model_paths = list() - config = self.manager.app_config - base_model = BaseModelType(base_model) + model_paths: List[Path] = list() + model_names = list() + config = self._config + store = self._store + base_models: Set[BaseModelType] = set() vae = None - for mod in model_names: - info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" + assert ( + len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference + ), "When merging three models, only the 'add_difference' merge method is supported" + + for key in model_keys: + info = store.get_model(key) + assert isinstance(info, MainConfig) + model_names.append(info.name) assert ( - info["model_format"] == "diffusers" - ), f"{mod} is not a diffusers model. It must be optimized before merging" - assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" + info.model_format == "diffusers" + ), f"{info.name} ({info.key}) is not a diffusers model. 
It must be optimized before merging" assert ( - len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference - ), "When merging three models, only the 'add_difference' merge method is supported" + info.variant == "normal" + ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" + # pick up the first model's vae - if mod == model_names[0]: - vae = info.get("vae") - model_paths.extend([(config.root_path / info["path"]).as_posix()]) + if key == model_keys[0]: + vae = info.vae + + # tally base models used + base_models.add(info.base_model) + model_paths.extend([(config.models_path / info.path).as_posix()]) + + assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}" + base_model = base_models.pop() merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) logger.debug(f"interp = {interp}, merge_method={merge_method}") @@ -124,17 +149,19 @@ def merge_diffusion_models_and_save( dump_path = (dump_path / merged_model_name).as_posix() merged_pipe.save_pretrained(dump_path, safe_serialization=True) - attributes = dict( - path=dump_path, - description=f"Merge of models {', '.join(model_names)}", - model_format="diffusers", - variant=ModelVariantType.Normal.value, - vae=vae, - ) - return self.manager.add_model( - merged_model_name, - base_model=base_model, - model_type=ModelType.Main, - model_attributes=attributes, - clobber=True, + + # register model and get its unique key + installer = ModelInstallService(store=self._store, config=self._config) + key = installer.register_path(dump_path) + + # update model's config + model_config = self._store.get_model(key) + model_config.update( + dict( + name=merged_model_name, + description=f"Merge of models {', '.join(model_names)}", + vae=vae, + ) ) + self._store.update_model(key, model_config) + return model_config diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_manager/models/__init__.py similarity index 96% rename from invokeai/backend/model_management/models/__init__.py rename to invokeai/backend/model_manager/models/__init__.py index bf4b208395c..6ea890c4c09 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_manager/models/__init__.py @@ -1,22 +1,20 @@ import inspect from enum import Enum -from typing import Literal, get_origin +from typing import Any, Literal, get_origin from pydantic import BaseModel from .base import ( # noqa: F401 BaseModelType, - DuplicateModelException, InvalidModelException, ModelBase, ModelConfigBase, - ModelError, ModelNotFoundException, ModelType, ModelVariantType, SchedulerPredictionType, - SilenceWarnings, SubModelType, + read_checkpoint_meta, ) from .clip_vision import CLIPVisionModel from .controlnet import ControlNetModel # TODO: @@ -97,14 +95,12 @@ # }, } -MODEL_CONFIGS = list() -OPENAPI_MODEL_CONFIGS = list() +MODEL_CONFIGS: Any = list() +OPENAPI_MODEL_CONFIGS: Any = list() class OpenAPIModelInfoBase(BaseModel): - model_name: str - base_model: BaseModelType - model_type: ModelType + key: str for base_model, models in MODEL_CLASSES.items(): diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_manager/models/base.py similarity index 90% rename from invokeai/backend/model_management/models/base.py rename to invokeai/backend/model_manager/models/base.py index 6e507735d4c..8cb5a69a3f0 100644 --- a/invokeai/backend/model_management/models/base.py +++ 
b/invokeai/backend/model_manager/models/base.py @@ -1,13 +1,14 @@ import inspect import json import os +import shutil import sys import typing -import warnings from abc import ABCMeta, abstractmethod from contextlib import suppress from enum import Enum from pathlib import Path +from types import ModuleType from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union import numpy as np @@ -15,90 +16,40 @@ import safetensors.torch import torch from diffusers import ConfigMixin, DiffusionPipeline -from diffusers import logging as diffusers_logging from onnx import numpy_helper from onnxruntime import InferenceSession, SessionOptions, get_available_providers from picklescan.scanner import scan_file_path -from pydantic import BaseModel, Field -from transformers import logging as transformers_logging +from invokeai.backend.util import GIG, directory_size +from invokeai.backend.util.logging import InvokeAILogger -class DuplicateModelException(Exception): - pass - - -class InvalidModelException(Exception): - pass +from ..config import ( # noqa F401 + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SubModelType, +) class ModelNotFoundException(Exception): - pass - - -class BaseModelType(str, Enum): - Any = "any" # For models that are not associated with any particular base model. - StableDiffusion1 = "sd-1" - StableDiffusion2 = "sd-2" - StableDiffusionXL = "sdxl" - StableDiffusionXLRefiner = "sdxl-refiner" - # Kandinsky2_1 = "kandinsky-2.1" - - -class ModelType(str, Enum): - ONNX = "onnx" - Main = "main" - Vae = "vae" - Lora = "lora" - ControlNet = "controlnet" # used by model_probe - TextualInversion = "embedding" - IPAdapter = "ip_adapter" - CLIPVision = "clip_vision" - T2IAdapter = "t2i_adapter" - - -class SubModelType(str, Enum): - UNet = "unet" - TextEncoder = "text_encoder" - TextEncoder2 = "text_encoder_2" - Tokenizer = "tokenizer" - Tokenizer2 = "tokenizer_2" - Vae = "vae" - VaeDecoder = "vae_decoder" - VaeEncoder = "vae_encoder" - Scheduler = "scheduler" - SafetyChecker = "safety_checker" - # MoVQ = "movq" - - -class ModelVariantType(str, Enum): - Normal = "normal" - Inpaint = "inpaint" - Depth = "depth" - - -class SchedulerPredictionType(str, Enum): - Epsilon = "epsilon" - VPrediction = "v_prediction" - Sample = "sample" - + """Exception for when a model is not found on the expected path.""" -class ModelError(str, Enum): - NotFound = "not_found" + pass -class ModelConfigBase(BaseModel): - path: str # or Path - description: Optional[str] = Field(None) - model_format: Optional[str] = Field(None) - error: Optional[ModelError] = Field(None) +class InvalidModelException(Exception): + """Exception for when a model is corrupted in some way; for example missing files.""" - class Config: - use_enum_values = True + pass class EmptyConfigLoader(ConfigMixin): @classmethod def load_config(cls, *args, **kwargs): + """Load empty configuration.""" cls.config_name = kwargs.pop("config_name") return super().load_config(*args, **kwargs) @@ -132,7 +83,7 @@ def __init__( self.base_model = base_model self.model_type = model_type - def _hf_definition_to_type(self, subtypes: List[str]) -> Type: + def _hf_definition_to_type(self, subtypes: List[str]) -> Optional[ModuleType]: if len(subtypes) < 2: raise Exception("Invalid subfolder definition!") if all(t is None for t in subtypes): @@ -231,6 +182,15 @@ def get_model( ) -> Any: raise NotImplementedError() + @classmethod + @abstractmethod + def convert_if_required( + 
cls, + model_config: ModelConfigBase, + output_path: str, + ) -> str: + raise NotImplementedError() + class DiffusersModel(ModelBase): # child_types: Dict[str, Type] @@ -453,22 +413,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): return checkpoint -class SilenceWarnings(object): - def __init__(self): - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - - def __enter__(self): - transformers_logging.set_verbosity_error() - diffusers_logging.set_verbosity_error() - warnings.simplefilter("ignore") - - def __exit__(self, type, value, traceback): - transformers_logging.set_verbosity(self.transformers_verbosity) - diffusers_logging.set_verbosity(self.diffusers_verbosity) - warnings.simplefilter("default") - - ONNX_WEIGHTS_NAME = "model.onnx" @@ -672,3 +616,34 @@ def from_pretrained( # TODO: session options return cls(model_path, provider=provider) + + +def trim_model_convert_cache(cache_path: Path, max_cache_size: int): + current_size = directory_size(cache_path) + logger = InvokeAILogger.get_logger() + + if current_size <= max_cache_size: + return + + logger.debug( + "Convert cache has gotten too large {(current_size / GIG):4.2f} > {(max_cache_size / GIG):4.2f}G.. Trimming." + ) + + # For this to work, we make the assumption that the directory contains + # either a 'unet/config.json' file, or a 'config.json' file at top level + def by_atime(path: Path) -> float: + for config in ["unet/config.json", "config.json"]: + sentinel = path / config + if sentinel.exists(): + return sentinel.stat().st_atime + return 0.0 + + # sort by last access time - least accessed files will be at the end + lru_models = sorted(cache_path.iterdir(), key=by_atime, reverse=True) + logger.debug(f"cached models in descending atime order: {lru_models}") + while current_size > max_cache_size and len(lru_models) > 0: + next_victim = lru_models.pop() + victim_size = directory_size(next_victim) + logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB") + shutil.rmtree(next_victim) + current_size -= victim_size diff --git a/invokeai/backend/model_management/models/clip_vision.py b/invokeai/backend/model_manager/models/clip_vision.py similarity index 97% rename from invokeai/backend/model_management/models/clip_vision.py rename to invokeai/backend/model_manager/models/clip_vision.py index 2276c6beed1..9e050894105 100644 --- a/invokeai/backend/model_management/models/clip_vision.py +++ b/invokeai/backend/model_manager/models/clip_vision.py @@ -5,7 +5,7 @@ import torch from transformers import CLIPVisionModelWithProjection -from invokeai.backend.model_management.models.base import ( +from invokeai.backend.model_manager.models.base import ( BaseModelType, InvalidModelException, ModelBase, diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_manager/models/controlnet.py similarity index 82% rename from invokeai/backend/model_management/models/controlnet.py rename to invokeai/backend/model_manager/models/controlnet.py index 359df91a820..e1cb4cfb60b 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_manager/models/controlnet.py @@ -8,7 +8,9 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig +from ..config import ControlNetCheckpointConfig, ControlNetDiffusersConfig from .base import ( + GIG, BaseModelType, EmptyConfigLoader, InvalidModelException, @@ 
-32,12 +34,11 @@ class ControlNetModel(ModelBase): # model_class: Type # model_size: int - class DiffusersConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Diffusers] + class DiffusersConfig(ControlNetDiffusersConfig): + model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers - class CheckpointConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Checkpoint] - config: str + class CheckpointConfig(ControlNetCheckpointConfig): + model_format: Literal[ControlNetModelFormat.Checkpoint] = ControlNetModelFormat.Checkpoint def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.ControlNet @@ -112,27 +113,22 @@ def detect_format(cls, path: str): @classmethod def convert_if_required( cls, - model_path: str, + model_config: ModelConfigBase, output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint: + if isinstance(model_config, ControlNetCheckpointConfig): return _convert_controlnet_ckpt_and_cache( - model_path=model_path, - model_config=config.config, + model_config=model_config, output_path=output_path, - base_model=base_model, ) else: - return model_path + return model_config.path def _convert_controlnet_ckpt_and_cache( - model_path: str, + model_config: ControlNetCheckpointConfig, output_path: str, - base_model: BaseModelType, - model_config: ControlNetModel.CheckpointConfig, + max_cache_size: int, ) -> str: """ Convert the controlnet from checkpoint format to diffusers format, @@ -140,7 +136,7 @@ def _convert_controlnet_ckpt_and_cache( file. If already on disk then just returns Path. """ app_config = InvokeAIAppConfig.get_config() - weights = app_config.root_path / model_path + weights = app_config.root_path / model_config.path output_path = Path(output_path) logger.info(f"Converting {weights} to diffusers format") @@ -148,6 +144,11 @@ def _convert_controlnet_ckpt_and_cache( if output_path.exists(): return output_path + # make sufficient size in the cache folder + size_needed = weights.stat().st_size + max_cache_size = (app_config.conversion_cache_size * GIG,) + trim_model_convert_cache(output_path.parent, max_cache_size - size_needed) + # to avoid circular import errors from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers diff --git a/invokeai/backend/model_management/models/ip_adapter.py b/invokeai/backend/model_manager/models/ip_adapter.py similarity index 87% rename from invokeai/backend/model_management/models/ip_adapter.py rename to invokeai/backend/model_manager/models/ip_adapter.py index 63694af0c87..6e5cafcb19e 100644 --- a/invokeai/backend/model_management/models/ip_adapter.py +++ b/invokeai/backend/model_manager/models/ip_adapter.py @@ -1,12 +1,11 @@ import os import typing -from enum import Enum from typing import Literal, Optional import torch from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter -from invokeai.backend.model_management.models.base import ( +from invokeai.backend.model_manager.models.base import ( BaseModelType, InvalidModelException, ModelBase, @@ -17,15 +16,12 @@ classproperty, ) - -class IPAdapterModelFormat(str, Enum): - # The custom IP-Adapter model format defined by InvokeAI. 
- InvokeAI = "invokeai" +from ..config import ModelFormat class IPAdapterModel(ModelBase): class InvokeAIConfig(ModelConfigBase): - model_format: Literal[IPAdapterModelFormat.InvokeAI] + model_format: Literal[ModelFormat.InvokeAI] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.IPAdapter @@ -42,7 +38,7 @@ def detect_format(cls, path: str) -> str: model_file = os.path.join(path, "ip_adapter.bin") image_encoder_config_file = os.path.join(path, "image_encoder.txt") if os.path.exists(model_file) and os.path.exists(image_encoder_config_file): - return IPAdapterModelFormat.InvokeAI + return ModelFormat.InvokeAI raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}") @@ -80,7 +76,7 @@ def convert_if_required( base_model: BaseModelType, ) -> str: format = cls.detect_format(model_path) - if format == IPAdapterModelFormat.InvokeAI: + if format == ModelFormat.InvokeAI: return model_path else: raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_manager/models/lora.py similarity index 98% rename from invokeai/backend/model_management/models/lora.py rename to invokeai/backend/model_manager/models/lora.py index b6f321d60b5..8258b6343e4 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_manager/models/lora.py @@ -2,11 +2,12 @@ import os from enum import Enum from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Literal, Optional, Union import torch from safetensors.torch import load_file +from ..config import LoRAConfig from .base import ( BaseModelType, InvalidModelException, @@ -27,8 +28,8 @@ class LoRAModelFormat(str, Enum): class LoRAModel(ModelBase): # model_size: int - class Config(ModelConfigBase): - model_format: LoRAModelFormat # TODO: + class Config(LoRAConfig): + model_format: Literal[LoRAModelFormat.LyCORIS] # TODO: def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Lora @@ -80,16 +81,14 @@ def detect_format(cls, path: str): @classmethod def convert_if_required( cls, - model_path: str, + model_config: ModelConfigBase, output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: + if cls.detect_format(model_config.path) == LoRAModelFormat.Diffusers: # TODO: add diffusers lora when it stabilizes a bit raise NotImplementedError("Diffusers lora not supported") else: - return model_path + return model_config.path class LoRALayerBase: diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_manager/models/sdxl.py similarity index 73% rename from invokeai/backend/model_management/models/sdxl.py rename to invokeai/backend/model_manager/models/sdxl.py index 41586e35b98..97ff3e0d328 100644 --- a/invokeai/backend/model_management/models/sdxl.py +++ b/invokeai/backend/model_manager/models/sdxl.py @@ -1,14 +1,13 @@ import json import os from enum import Enum -from typing import Literal, Optional +from typing import Literal from omegaconf import OmegaConf -from pydantic import Field +from ..config import MainDiffusersConfig from .base import ( BaseModelType, - DiffusersModel, InvalidModelException, ModelConfigBase, ModelType, @@ -16,6 +15,7 @@ classproperty, read_checkpoint_meta, ) +from .stable_diffusion import StableDiffusionModelBase class 
StableDiffusionXLModelFormat(str, Enum): @@ -23,18 +23,13 @@ class StableDiffusionXLModelFormat(str, Enum): Diffusers = "diffusers" -class StableDiffusionXLModel(DiffusersModel): +class StableDiffusionXLModel(StableDiffusionModelBase): # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): + class DiffusersConfig(MainDiffusersConfig): model_format: Literal[StableDiffusionXLModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType class CheckpointConfig(ModelConfigBase): model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner} @@ -104,26 +99,3 @@ def detect_format(cls, model_path: str): return StableDiffusionXLModelFormat.Diffusers else: return StableDiffusionXLModelFormat.Checkpoint - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - # The convert script adapted from the diffusers package uses - # strings for the base model type. To avoid making too many - # source code changes, we simply translate here - if isinstance(config, cls.CheckpointConfig): - from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache - - return _convert_ckpt_and_cache( - version=base_model, - model_config=config, - output_path=output_path, - use_safetensors=False, # corrupts sdxl models for some reason - ) - else: - return model_path diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_manager/models/stable_diffusion.py similarity index 87% rename from invokeai/backend/model_management/models/stable_diffusion.py rename to invokeai/backend/model_manager/models/stable_diffusion.py index ffce42d9e96..30a2c5d6e96 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_manager/models/stable_diffusion.py @@ -2,7 +2,7 @@ import os from enum import Enum from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal, Optional from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline from omegaconf import OmegaConf @@ -11,6 +11,8 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig +from ..cache import GIG +from ..config import MainCheckpointConfig, MainDiffusersConfig, SilenceWarnings from .base import ( BaseModelType, DiffusersModel, @@ -19,11 +21,10 @@ ModelNotFoundException, ModelType, ModelVariantType, - SilenceWarnings, classproperty, read_checkpoint_meta, + trim_model_convert_cache, ) -from .sdxl import StableDiffusionXLModel class StableDiffusion1ModelFormat(str, Enum): @@ -31,17 +32,31 @@ class StableDiffusion1ModelFormat(str, Enum): Diffusers = "diffusers" -class StableDiffusion1Model(DiffusersModel): - class DiffusersConfig(ModelConfigBase): +class StableDiffusionModelBase(DiffusersModel): + """Base class that defines common class methodsd.""" + + @classmethod + def convert_if_required( + cls, + model_config: ModelConfigBase, + output_path: str, + ) -> str: + if isinstance(model_config, MainCheckpointConfig): + return _convert_ckpt_and_cache( + model_config=model_config, + output_path=output_path, + use_safetensors=False, # corrupts sdxl models for 
some reason + ) + else: + return model_config.path + + +class StableDiffusion1Model(StableDiffusionModelBase): + class DiffusersConfig(MainDiffusersConfig): model_format: Literal[StableDiffusion1ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - class CheckpointConfig(ModelConfigBase): + class CheckpointConfig(MainCheckpointConfig): model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 @@ -115,31 +130,13 @@ def detect_format(cls, model_path: str): raise InvalidModelException(f"Not a valid model: {model_path}") - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion1, - model_config=config, - load_safety_checker=False, - output_path=output_path, - ) - else: - return model_path - class StableDiffusion2ModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" -class StableDiffusion2Model(DiffusersModel): +class StableDiffusion2Model(StableDiffusionModelBase): # TODO: check that configs overwriten properly class DiffusersConfig(ModelConfigBase): model_format: Literal[StableDiffusion2ModelFormat.Diffusers] @@ -226,33 +223,10 @@ def detect_format(cls, model_path: str): raise InvalidModelException(f"Not a valid model: {model_path}") - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion2, - model_config=config, - output_path=output_path, - ) - else: - return model_path - # TODO: rework -# pass precision - currently defaulting to fp16 def _convert_ckpt_and_cache( - version: BaseModelType, - model_config: Union[ - StableDiffusion1Model.CheckpointConfig, - StableDiffusion2Model.CheckpointConfig, - StableDiffusionXLModel.CheckpointConfig, - ], + model_config: ModelConfigBase, output_path: str, use_save_model: bool = False, **kwargs, @@ -263,17 +237,22 @@ def _convert_ckpt_and_cache( file. If already on disk then just returns Path. 
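+    Before conversion, the cache folder is trimmed via trim_model_convert_cache()
+    so that the newly converted model fits within the configured
+    conversion_cache_size.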
""" app_config = InvokeAIAppConfig.get_config() - + version = model_config.base_model.value weights = app_config.models_path / model_config.path config_file = app_config.root_path / model_config.config output_path = Path(output_path) variant = model_config.variant pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline + max_cache_size = app_config.conversion_cache_size * GIG # return cached version if it exists if output_path.exists(): return output_path + # make sufficient size in the cache folder + size_needed = weights.stat().st_size + trim_model_convert_cache(output_path.parent, max_cache_size - size_needed) + # to avoid circular import errors from ...util.devices import choose_torch_device, torch_dtype from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_manager/models/stable_diffusion_onnx.py similarity index 93% rename from invokeai/backend/model_management/models/stable_diffusion_onnx.py rename to invokeai/backend/model_manager/models/stable_diffusion_onnx.py index 2d0dd22c43a..085baf0fb70 100644 --- a/invokeai/backend/model_management/models/stable_diffusion_onnx.py +++ b/invokeai/backend/model_manager/models/stable_diffusion_onnx.py @@ -3,6 +3,7 @@ from diffusers import OnnxRuntimeModel +from ..config import ONNXSD1Config, ONNXSD2Config from .base import ( BaseModelType, DiffusersModel, @@ -21,9 +22,8 @@ class StableDiffusionOnnxModelFormat(str, Enum): class ONNXStableDiffusion1Model(DiffusersModel): - class Config(ModelConfigBase): + class Config(ONNXSD1Config): model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 @@ -72,19 +72,16 @@ def convert_if_required( cls, model_path: str, output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, + # config: ModelConfigBase, # not used? + # base_model: BaseModelType, # not used? 
) -> str: return model_path class ONNXStableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly - class Config(ModelConfigBase): + class Config(ONNXSD2Config): model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion2 diff --git a/invokeai/backend/model_management/models/t2i_adapter.py b/invokeai/backend/model_manager/models/t2i_adapter.py similarity index 98% rename from invokeai/backend/model_management/models/t2i_adapter.py rename to invokeai/backend/model_manager/models/t2i_adapter.py index 4adb9901f99..995ad8d38f4 100644 --- a/invokeai/backend/model_management/models/t2i_adapter.py +++ b/invokeai/backend/model_manager/models/t2i_adapter.py @@ -5,7 +5,7 @@ import torch from diffusers import T2IAdapter -from invokeai.backend.model_management.models.base import ( +from .base import ( BaseModelType, EmptyConfigLoader, InvalidModelException, diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_manager/models/textual_inversion.py similarity index 80% rename from invokeai/backend/model_management/models/textual_inversion.py rename to invokeai/backend/model_manager/models/textual_inversion.py index b59e6350450..74c8b47a57f 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_manager/models/textual_inversion.py @@ -1,8 +1,10 @@ import os -from typing import Optional +from typing import Literal, Optional import torch +from ..config import ModelFormat, TextualInversionConfig + # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw from .base import ( @@ -20,8 +22,15 @@ class TextualInversionModel(ModelBase): # model_size: int - class Config(ModelConfigBase): - model_format: None + class FolderConfig(TextualInversionConfig): + """Config for embeddings that are represented as a folder containing learned_embeds.bin.""" + + model_format: Literal[ModelFormat.EmbeddingFolder] + + class FileConfig(TextualInversionConfig): + """Config for embeddings that are contained in safetensors/checkpoint files.""" + + model_format: Literal[ModelFormat.EmbeddingFile] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.TextualInversion @@ -79,9 +88,7 @@ def detect_format(cls, path: str): @classmethod def convert_if_required( cls, - model_path: str, + model_config: ModelConfigBase, output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, ) -> str: - return model_path + return model_config.path diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_manager/models/vae.py similarity index 84% rename from invokeai/backend/model_management/models/vae.py rename to invokeai/backend/model_manager/models/vae.py index 637160c69b0..9897feb6183 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_manager/models/vae.py @@ -1,7 +1,7 @@ import os from enum import Enum from pathlib import Path -from typing import Optional +from typing import Literal, Optional import safetensors import torch @@ -9,7 +9,9 @@ from invokeai.app.services.config import InvokeAIAppConfig +from ..config import VaeCheckpointConfig, VaeDiffusersConfig from .base import ( + GIG, BaseModelType, EmptyConfigLoader, 
InvalidModelException, @@ -22,6 +24,7 @@ calc_model_size_by_data, calc_model_size_by_fs, classproperty, + trim_model_convert_cache, ) @@ -34,8 +37,11 @@ class VaeModel(ModelBase): # vae_class: Type # model_size: int - class Config(ModelConfigBase): - model_format: VaeModelFormat + class DiffusersConfig(VaeDiffusersConfig): + model_format: Literal[VaeModelFormat.Diffusers] = VaeModelFormat.Diffusers + + class CheckpointConfig(VaeCheckpointConfig): + model_format: Literal[VaeModelFormat.Checkpoint] = VaeModelFormat.Checkpoint def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Vae @@ -97,28 +103,22 @@ def detect_format(cls, path: str): @classmethod def convert_if_required( cls, - model_path: str, + model_config: ModelConfigBase, output_path: str, - config: ModelConfigBase, # empty config or config of parent model - base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: + if isinstance(model_config, VaeCheckpointConfig): return _convert_vae_ckpt_and_cache( - weights_path=model_path, + model_config=model_config, output_path=output_path, - base_model=base_model, - model_config=config, ) else: - return model_path + return model_config.path -# TODO: rework def _convert_vae_ckpt_and_cache( - weights_path: str, - output_path: str, - base_model: BaseModelType, model_config: ModelConfigBase, + output_path: str, + max_cache_size: int, ) -> str: """ Convert the VAE indicated in mconfig into a diffusers AutoencoderKL @@ -126,7 +126,7 @@ def _convert_vae_ckpt_and_cache( file. If already on disk then just returns Path. """ app_config = InvokeAIAppConfig.get_config() - weights_path = app_config.root_dir / weights_path + weights_path = app_config.root_dir / model_config.path output_path = Path(output_path) """ @@ -148,6 +148,12 @@ def _convert_vae_ckpt_and_cache( if output_path.exists(): return output_path + # make sufficient size in the cache folder + size_needed = weights_path.stat().st_size + max_cache_size = (app_config.conversion_cache_size * GIG,) + trim_model_convert_cache(output_path.parent, max_cache_size - size_needed) + + base_model = model_config.base_model if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: from .stable_diffusion import _select_ckpt_config diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_manager/probe.py similarity index 60% rename from invokeai/backend/model_management/model_probe.py rename to invokeai/backend/model_manager/probe.py index 19d64b035fb..8be0233def1 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -1,47 +1,89 @@ +# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team +""" +Return descriptive information on Stable Diffusion models. + +Module for probing a Stable Diffusion model and returning +its base type, model type, format and variant. 
+""" + import json import re -from dataclasses import dataclass +from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Dict, Literal, Optional, Union +from typing import Callable, Dict, Optional, Type import safetensors.torch import torch -from diffusers import ConfigMixin, ModelMixin from picklescan.scanner import scan_file_path +from pydantic import BaseModel + +from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType +from .hash import FastModelHash +from .util import lora_token_vector_length, read_checkpoint_meta + -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat +class InvalidModelException(Exception): + """Raised when an invalid model is encountered.""" -from .models import ( - BaseModelType, - InvalidModelException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, -) -from .models.base import read_checkpoint_meta -from .util import lora_token_vector_length +class ModelProbeInfo(BaseModel): + """Fields describing a probed model.""" -@dataclass -class ModelProbeInfo(object): model_type: ModelType base_type: BaseModelType - variant_type: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"] - image_size: int + format: ModelFormat + hash: str + variant_type: ModelVariantType = ModelVariantType("normal") + prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction") + upcast_attention: Optional[bool] = False + image_size: Optional[int] = None -class ProbeBase(object): - """forward declaration""" +class ModelProbeBase(ABC): + """Class to probe a checkpoint, safetensors or diffusers folder.""" + + @classmethod + @abstractmethod + def probe( + cls, + model: Path, + prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, + ) -> Optional[ModelProbeInfo]: + """ + Probe model located at path and return ModelProbeInfo object. + + :param model: Path to a model checkpoint or folder. + :param prediction_type_helper: An optional Callable that takes the model path + and returns the SchedulerPredictionType. 
+ """ + pass + + +class ProbeBase(ABC): + """Base model for probing checkpoint and diffusers-style models.""" + + @abstractmethod + def get_base_type(self) -> Optional[BaseModelType]: + """Return the BaseModelType for the model.""" + pass + + def get_variant_type(self) -> ModelVariantType: + """Return the ModelVariantType for the model.""" + pass + + def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: + """Return the SchedulerPredictionType for the model.""" + pass + + def get_format(self) -> str: + """Return the format for the model.""" + pass - pass +class ModelProbe(ModelProbeBase): + """Class to probe a checkpoint, safetensors or diffusers folder.""" -class ModelProbe(object): - PROBES = { + PROBES: Dict[str, dict] = { "diffusers": {}, "checkpoint": {}, "onnx": {}, @@ -52,7 +94,6 @@ class ModelProbe(object): "StableDiffusionInpaintPipeline": ModelType.Main, "StableDiffusionXLPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, "AutoencoderKL": ModelType.Vae, "AutoencoderTiny": ModelType.Vae, "ControlNetModel": ModelType.ControlNet, @@ -61,58 +102,46 @@ class ModelProbe(object): } @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase - ): - cls.PROBES[format][model_type] = probe_class + def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: Type[ProbeBase]): + """ + Register a probe subclass to use when interrogating a model. - @classmethod - def heuristic_probe( - cls, - model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, - ) -> ModelProbeInfo: - if isinstance(model, Path): - return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper) - elif isinstance(model, (dict, ModelMixin, ConfigMixin)): - return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) - else: - raise InvalidModelException("model parameter {model} is neither a Path, nor a model") + :param format: The ModelFormat of the model to be probed. + :param model_type: The ModelType of the model to be probed. + :param probe_class: The class of the prober (inherits from ProbeBase). + """ + cls.PROBES[format][model_type] = probe_class @classmethod def probe( cls, model_path: Path, - model: Optional[Union[Dict, ModelMixin]] = None, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, ) -> ModelProbeInfo: - """ - Probe the model at model_path and return sufficient information about it - to place it somewhere in the models directory hierarchy. If the model is - already loaded into memory, you may provide it as model in order to avoid - opening it a second time. The prediction_type_helper callable is a function that receives - the path to the model and returns the SchedulerPredictionType. 
- """ - if model_path: - format_type = "diffusers" if model_path.is_dir() else "checkpoint" - else: - format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint" - model_info = None + """Probe model.""" try: model_type = ( - cls.get_model_type_from_folder(model_path, model) - if format_type == "diffusers" - else cls.get_model_type_from_checkpoint(model_path, model) + cls.get_model_type_from_folder(model_path) + if model_path.is_dir() + else cls.get_model_type_from_checkpoint(model_path) + ) + format_type = ( + "onnx" if model_type == ModelType.ONNX else "diffusers" if model_path.is_dir() else "checkpoint" ) - format_type = "onnx" if model_type == ModelType.ONNX else format_type + probe_class = cls.PROBES[format_type].get(model_type) + if not probe_class: - return None - probe = probe_class(model_path, model, prediction_type_helper) + raise InvalidModelException(f"Unable to determine model type for {model_path}") + + probe = probe_class(model_path, prediction_type_helper) + base_type = probe.get_base_type() variant_type = probe.get_variant_type() prediction_type = probe.get_scheduler_prediction_type() format = probe.get_format() + hash = FastModelHash.hash(model_path) + model_info = ModelProbeInfo( model_type=model_type, base_type=base_type, @@ -123,33 +152,35 @@ def probe( and prediction_type == SchedulerPredictionType.VPrediction ), format=format, - image_size=( - 1024 - if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) - else ( - 768 - if ( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ) - else 512 - ) - ), + hash=hash, + image_size=1024 + if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) + else 768 + if ( + base_type == BaseModelType.StableDiffusion2 + and prediction_type == SchedulerPredictionType.VPrediction + ) + else 512, ) except Exception: - raise + raise InvalidModelException(f"Unable to determine model type for {model_path}") return model_info @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType: + def get_model_type_from_checkpoint(cls, model_path: Path) -> Optional[ModelType]: + """ + Scan a checkpoint model and return its ModelType. + + :param model_path: path to the model checkpoint/safetensors file + """ if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): return None if model_path.name == "learned_embeds.bin": return ModelType.TextualInversion - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) + ckpt = read_checkpoint_meta(model_path, scan=True) ckpt = ckpt.get("state_dict", ckpt) for key in ckpt.keys(): @@ -174,39 +205,37 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> M raise InvalidModelException(f"Unable to determine model type for {model_path}") @classmethod - def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType: + def get_model_type_from_folder(cls, folder_path: Path) -> Optional[ModelType]: """ Get the model type of a hugging-face style folder. + + :param folder_path: Path to model folder. 
""" class_name = None - error_hint = None - if model: - class_name = model.__class__.__name__ - else: - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "learned_embeds.bin").exists(): - return ModelType.TextualInversion - if (folder_path / "pytorch_lora_weights.bin").exists(): - return ModelType.Lora - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - i = folder_path / "model_index.json" - c = folder_path / "config.json" - config_path = i if i.exists() else c if c.exists() else None - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None + if (folder_path / "unet/model.onnx").exists(): + return ModelType.ONNX + if (folder_path / "learned_embeds.bin").exists(): + return ModelType.TextualInversion + if (folder_path / "pytorch_lora_weights.bin").exists(): + return ModelType.Lora + if (folder_path / "image_encoder.txt").exists(): + return ModelType.IPAdapter + + i = folder_path / "model_index.json" + c = folder_path / "config.json" + config_path = i if i.exists() else c if c.exists() else None + + if config_path: + with open(config_path, "r") as file: + conf = json.load(file) + if "_class_name" in conf: + class_name = conf["_class_name"] + elif "architectures" in conf: + class_name = conf["architectures"][0] else: - error_hint = f"No model_index.json or config.json found in {folder_path}." + class_name = None + else: + error_hint = f"No model_index.json or config.json found in {folder_path}." if class_name and (type := cls.CLASS2TYPE.get(class_name)): return type @@ -219,59 +248,52 @@ def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> Mod ) @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): - cls._scan_model(model_path, model_path) - return torch.load(model_path) - else: - return safetensors.torch.load_file(model_path) + def _scan_and_load_checkpoint(cls, model: Path) -> dict: + if model.suffix.endswith((".ckpt", ".pt", ".bin")): + cls._scan_model(model) + return torch.load(model) + else: + return safetensors.torch.load_file(model) @classmethod - def _scan_model(cls, model_name, checkpoint): + def _scan_model(cls, model: Path): """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. + Scan a model for malicious code. + + :param model: Path to the model to be scanned + Raises an Exception if unsafe code is found. """ # scan model - scan_result = scan_file_path(checkpoint) + scan_result = scan_file_path(model) if scan_result.infected_files != 0: - raise "The model {model_name} is potentially infected by malware. Aborting import." + raise InvalidModelException("The model {model_name} is potentially infected by malware. 
Aborting import.") # ##################################################3 # Checkpoint probing # ##################################################3 -class ProbeBase(object): - def get_base_type(self) -> BaseModelType: - pass - - def get_variant_type(self) -> ModelVariantType: - pass - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - pass - - def get_format(self) -> str: - pass class CheckpointProbeBase(ProbeBase): - def __init__( - self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None - ) -> BaseModelType: - self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) + """Base class for probing checkpoint-style models.""" + + def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None): + """Initialize the CheckpointProbeBase object.""" self.checkpoint_path = checkpoint_path + self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.helper = helper - def get_base_type(self) -> BaseModelType: + def get_base_type(self) -> Optional[BaseModelType]: + """Return the BaseModelType of a checkpoint-style model.""" pass def get_format(self) -> str: + """Return the format of a checkpoint-style model.""" return "checkpoint" def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint) + """Return the ModelVariantType of a checkpoint-style model.""" + model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path) if model_type != ModelType.Main: return ModelVariantType.Normal state_dict = self.checkpoint.get("state_dict") or self.checkpoint @@ -289,7 +311,10 @@ def get_variant_type(self) -> ModelVariantType: class PipelineCheckpointProbe(CheckpointProbeBase): + """Probe a checkpoint-style main model.""" + def get_base_type(self) -> BaseModelType: + """Return the ModelBaseType for the checkpoint-style main model.""" checkpoint = self.checkpoint state_dict = self.checkpoint.get("state_dict") or checkpoint key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" @@ -338,16 +363,23 @@ def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: class VaeCheckpointProbe(CheckpointProbeBase): + """Probe a Checkpoint-style VAE model.""" + def get_base_type(self) -> BaseModelType: + """Return the BaseModelType of the VAE model.""" # I can't find any standalone 2.X VAEs to test with! 
return BaseModelType.StableDiffusion1 class LoRACheckpointProbe(CheckpointProbeBase): + """Probe for LoRA Checkpoint Files.""" + def get_format(self) -> str: + """Return the format of the LoRA.""" return "lycoris" def get_base_type(self) -> BaseModelType: + """Return the BaseModelType of the LoRA.""" checkpoint = self.checkpoint token_vector_length = lora_token_vector_length(checkpoint) @@ -358,14 +390,18 @@ def get_base_type(self) -> BaseModelType: elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: - raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}") + raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}") class TextualInversionCheckpointProbe(CheckpointProbeBase): + """TextualInversion checkpoint prober.""" + def get_format(self) -> str: - return None + """Return the format of a TextualInversion emedding.""" + return ModelFormat.EmbeddingFile def get_base_type(self) -> BaseModelType: + """Return BaseModelType of the checkpoint model.""" checkpoint = self.checkpoint if "string_to_token" in checkpoint: token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] @@ -377,12 +413,14 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.StableDiffusion1 elif token_dim == 1024: return BaseModelType.StableDiffusion2 - else: - return None + raise InvalidModelException("Unknown base model for {self.checkpoint_path}") class ControlNetCheckpointProbe(CheckpointProbeBase): + """Probe checkpoint-based ControlNet models.""" + def get_base_type(self) -> BaseModelType: + """Return the BaseModelType of the model.""" checkpoint = self.checkpoint for key_name in ( "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", @@ -394,18 +432,22 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.StableDiffusion1 elif checkpoint[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 - elif self.checkpoint_path and self.helper: - return self.helper(self.checkpoint_path) raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") class IPAdapterCheckpointProbe(CheckpointProbeBase): + """Probe IP adapter models.""" + def get_base_type(self) -> BaseModelType: + """Probe base type.""" raise NotImplementedError() class CLIPVisionCheckpointProbe(CheckpointProbeBase): + """Probe ClipVision adapter models.""" + def get_base_type(self) -> BaseModelType: + """Probe base type.""" raise NotImplementedError() @@ -418,24 +460,33 @@ def get_base_type(self) -> BaseModelType: # classes for probing folders ####################################################### class FolderProbeBase(ProbeBase): - def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used - self.model = model + """Class for probing folder-based models.""" + + def __init__(self, folder_path: Path, helper: Optional[Callable] = None): # not used + """ + Initialize the folder prober. + + :param model: Path to the model to be probed. + :param helper: Callable for returning the SchedulerPredictionType (unused). 
+ """ self.folder_path = folder_path def get_variant_type(self) -> ModelVariantType: + """Return the model's variant type.""" return ModelVariantType.Normal def get_format(self) -> str: + """Return the model's format.""" return "diffusers" class PipelineFolderProbe(FolderProbeBase): + """Probe a pipeline (main) folder.""" + def get_base_type(self) -> BaseModelType: - if self.model: - unet_conf = self.model.unet.config - else: - with open(self.folder_path / "unet" / "config.json", "r") as file: - unet_conf = json.load(file) + """Return the BaseModelType of a pipeline folder.""" + with open(self.folder_path / "unet" / "config.json", "r") as file: + unet_conf = json.load(file) if unet_conf["cross_attention_dim"] == 768: return BaseModelType.StableDiffusion1 elif unet_conf["cross_attention_dim"] == 1024: @@ -448,29 +499,21 @@ def get_base_type(self) -> BaseModelType: raise InvalidModelException(f"Unknown base model for {self.folder_path}") def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - if self.model: - scheduler_conf = self.model.scheduler.config - else: - with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf["prediction_type"] == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf["prediction_type"] == "epsilon": - return SchedulerPredictionType.Epsilon - else: - return None + """Return the SchedulerPredictionType of a diffusers-style sd-2 model.""" + with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: + scheduler_conf = json.load(file) + prediction_type = scheduler_conf.get("prediction_type", "epsilon") + return SchedulerPredictionType(prediction_type) def get_variant_type(self) -> ModelVariantType: + """Return the ModelVariantType for diffusers-style main models.""" # This only works for pipelines! 
Any kind of # exception results in our returning the # "normal" variant type try: - if self.model: - conf = self.model.unet.config - else: - config_file = self.folder_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) + config_file = self.folder_path / "unet" / "config.json" + with open(config_file, "r") as file: + conf = json.load(file) in_channels = conf["in_channels"] if in_channels == 9: @@ -485,7 +528,10 @@ def get_variant_type(self) -> ModelVariantType: class VaeFolderProbe(FolderProbeBase): + """Class for probing folder-style models.""" + def get_base_type(self) -> BaseModelType: + """Get base type of model.""" if self._config_looks_like_sdxl(): return BaseModelType.StableDiffusionXL elif self._name_looks_like_sdxl(): @@ -515,30 +561,41 @@ def _guess_name(self) -> str: class TextualInversionFolderProbe(FolderProbeBase): + """Probe a HuggingFace-style TextualInversion folder.""" + def get_format(self) -> str: - return None + """Return the format of the TextualInversion.""" + return ModelFormat.EmbeddingFolder def get_base_type(self) -> BaseModelType: + """Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder.""" path = self.folder_path / "learned_embeds.bin" if not path.exists(): - return None - checkpoint = ModelProbe._scan_and_load_checkpoint(path) - return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type() + raise InvalidModelException("This textual inversion folder does not contain a learned_embeds.bin file.") + return TextualInversionCheckpointProbe(path).get_base_type() class ONNXFolderProbe(FolderProbeBase): + """Probe an ONNX-format folder.""" + def get_format(self) -> str: + """Return the format of the folder (always "onnx").""" return "onnx" def get_base_type(self) -> BaseModelType: + """Return the BaseModelType of the ONNX folder.""" return BaseModelType.StableDiffusion1 def get_variant_type(self) -> ModelVariantType: + """Return the ModelVariantType of the ONNX folder.""" return ModelVariantType.Normal class ControlNetFolderProbe(FolderProbeBase): + """Probe a ControlNet model folder.""" + def get_base_type(self) -> BaseModelType: + """Return the BaseModelType of a ControlNet model folder.""" config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") @@ -549,13 +606,11 @@ def get_base_type(self) -> BaseModelType: base_model = ( BaseModelType.StableDiffusion1 if dimension == 768 - else ( - BaseModelType.StableDiffusion2 - if dimension == 1024 - else BaseModelType.StableDiffusionXL - if dimension == 2048 - else None - ) + else BaseModelType.StableDiffusion2 + if dimension == 1024 + else BaseModelType.StableDiffusionXL + if dimension == 2048 + else None ) if not base_model: raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") @@ -563,7 +618,10 @@ def get_base_type(self) -> BaseModelType: class LoRAFolderProbe(FolderProbeBase): + """Probe a LoRA model folder.""" + def get_base_type(self) -> BaseModelType: + """Get the ModelBaseType of a LoRA model folder.""" model_file = None for suffix in ["safetensors", "bin"]: base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" @@ -572,14 +630,18 @@ def get_base_type(self) -> BaseModelType: break if not model_file: raise InvalidModelException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file, None).get_base_type() + return LoRACheckpointProbe(model_file).get_base_type() class 
IPAdapterFolderProbe(FolderProbeBase): + """Class for probing IP-Adapter models.""" + def get_format(self) -> str: - return IPAdapterModelFormat.InvokeAI.value + """Get format of ip adapter.""" + return ModelFormat.InvokeAI.value def get_base_type(self) -> BaseModelType: + """Get base type of ip adapter.""" model_file = self.folder_path / "ip_adapter.bin" if not model_file.exists(): raise InvalidModelException("Unknown IP-Adapter model format.") @@ -597,7 +659,10 @@ def get_base_type(self) -> BaseModelType: class CLIPVisionFolderProbe(FolderProbeBase): + """Probe for folder-based CLIPVision models.""" + def get_base_type(self) -> BaseModelType: + """Get base type.""" return BaseModelType.Any @@ -622,22 +687,25 @@ def get_base_type(self) -> BaseModelType: ############## register probe classes ###### -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) +diffusers = ModelFormat("diffusers") +checkpoint = ModelFormat("checkpoint") + +ModelProbe.register_probe(diffusers, ModelType.Main, PipelineFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.Vae, VaeFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.Lora, LoRAFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.TextualInversion, TextualInversionFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.ControlNet, ControlNetFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.IPAdapter, IPAdapterFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.CLIPVision, CLIPVisionFolderProbe) +ModelProbe.register_probe(diffusers, ModelType.T2IAdapter, T2IAdapterFolderProbe) + +ModelProbe.register_probe(checkpoint, ModelType.Main, PipelineCheckpointProbe) +ModelProbe.register_probe(checkpoint, ModelType.Vae, VaeCheckpointProbe) +ModelProbe.register_probe(checkpoint, ModelType.Lora, LoRACheckpointProbe) +ModelProbe.register_probe(checkpoint, ModelType.TextualInversion, TextualInversionCheckpointProbe) +ModelProbe.register_probe(checkpoint, ModelType.ControlNet, ControlNetCheckpointProbe) +ModelProbe.register_probe(checkpoint, ModelType.IPAdapter, IPAdapterCheckpointProbe) +ModelProbe.register_probe(checkpoint, ModelType.CLIPVision, CLIPVisionCheckpointProbe) 
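+# Illustrative usage sketch (hypothetical path and class names): probing a model on
+# disk and reading back the resulting ModelProbeInfo, and plugging in a custom prober
+# alongside the built-in registrations in this section.
+#
+#   info = ModelProbe.probe(Path("/tmp/models/some-model.safetensors"))
+#   print(info.model_type, info.base_type, info.format, info.variant_type, info.hash)
+#
+#   ModelProbe.register_probe(checkpoint, ModelType.Main, MyMainCheckpointProbe)
+#   # where MyMainCheckpointProbe is a hypothetical subclass of CheckpointProbeBase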
+ModelProbe.register_probe(checkpoint, ModelType.T2IAdapter, T2IAdapterCheckpointProbe) + +ModelProbe.register_probe(ModelFormat("onnx"), ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_manager/seamless.py similarity index 100% rename from invokeai/backend/model_management/seamless.py rename to invokeai/backend/model_manager/seamless.py diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py new file mode 100644 index 00000000000..4314249cbbd --- /dev/null +++ b/invokeai/backend/model_manager/search.py @@ -0,0 +1,198 @@ +# Copyright 2023, Lincoln D. Stein and the InvokeAI Team +""" +Abstract base class and implementation for recursive directory search for models. + +Example usage: +``` + from invokeai.backend.model_manager import ModelSearch, ModelProbe + + def find_main_models(model: Path) -> bool: + info = ModelProbe.probe(model) + if info.model_type == 'main' and info.base_type == 'sd-1': + return True + else: + return False + + search = ModelSearch(on_model_found=report_it) + found = search.search('/tmp/models') + print(found) # list of matching model paths + print(search.stats) # search stats +``` +""" + +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Optional, Set, Union + +from pydantic import BaseModel, Field + +from invokeai.backend.util import InvokeAILogger, Logger + +default_logger = InvokeAILogger.get_logger() + + +class SearchStats(BaseModel): + items_scanned: int = 0 + models_found: int = 0 + models_filtered: int = 0 + + +class ModelSearchBase(ABC, BaseModel): + """ + Abstract directory traversal model search class + + Usage: + search = ModelSearchBase( + on_search_started = search_started_callback, + on_search_completed = search_completed_callback, + on_model_found = model_found_callback, + ) + models_found = search.search('/path/to/directory') + """ + + # fmt: off + on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221 + on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 + on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221 + stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 + logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221 + # fmt: on + + class Config: + underscore_attrs_are_private = True + arbitrary_types_allowed = True + + @abstractmethod + def search_started(self): + """ + Called before the scan starts. + + Passes the root search directory to the Callable `on_search_started`. + """ + pass + + @abstractmethod + def model_found(self, model: Path): + """ + Called when a model is found during search. + + :param model: Model to process - could be a directory or checkpoint. + + Passes the model's Path to the Callable `on_model_found`. + This Callable receives the path to the model and returns a boolean + to indicate whether the model should be returned in the search + results. + """ + pass + + @abstractmethod + def search_completed(self): + """ + Called before the scan starts. + + Passes the Set of found model Paths to the Callable `on_search_completed`. 
+ """ + pass + + @abstractmethod + def search(self, directory: Union[Path, str]) -> Set[Path]: + """ + Recursively search for models in `directory` and return a set of model paths. + + If provided, the `on_search_started`, `on_model_found` and `on_search_completed` + Callables will be invoked during the search. + """ + pass + + +class ModelSearch(ModelSearchBase): + """ + Implementation of ModelSearch with callbacks. + Usage: + search = ModelSearch() + search.model_found = lambda path : 'anime' in path.as_posix() + found = search.list_models(['/tmp/models1','/tmp/models2']) + # returns all models that have 'anime' in the path + """ + + _directory: Path = Field(default=None) + _models_found: Set[Path] = Field(default=None) + _scanned_dirs: Set[Path] = Field(default=None) + _pruned_paths: Set[Path] = Field(default=None) + + def search_started(self): + self._models_found = set() + self._scanned_dirs = set() + self._pruned_paths = set() + if self.on_search_started: + self.on_search_started(self._directory) + + def model_found(self, model: Path): + self.stats.models_found += 1 + if not self.on_model_found: + self.stats.models_filtered += 1 + self._models_found.add(model) + return + if self.on_model_found(model): + self.stats.models_filtered += 1 + self._models_found.add(model) + + def search_completed(self): + if self.on_search_completed: + self.on_search_completed(self._models_found) + + def search(self, directory: Union[Path, str]) -> Set[Path]: + self._directory = Path(directory) + self.stats = SearchStats() # zero out + self.search_started() # This will initialize _models_found to empty + self._walk_directory(directory) + self.search_completed() + return self._models_found + + def _walk_directory(self, path: Union[Path, str]): + for root, dirs, files in os.walk(path, followlinks=True): + # don't descend into directories that start with a "." + # to avoid the Mac .DS_STORE issue. 
+ if str(Path(root).name).startswith("."): + self._pruned_paths.add(Path(root)) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self.stats.items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any( + [ + (path / x).exists() + for x in [ + "config.json", + "model_index.json", + "learned_embeds.bin", + "pytorch_lora_weights.bin", + "image_encoder.txt", + ] + ] + ): + self._scanned_dirs.add(path) + try: + self.model_found(path) + except KeyboardInterrupt: + raise + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: + try: + self.model_found(path) + except KeyboardInterrupt: + raise + except Exception as e: + self.logger.warning(str(e)) diff --git a/invokeai/backend/model_manager/storage/__init__.py b/invokeai/backend/model_manager/storage/__init__.py new file mode 100644 index 00000000000..dec4d4d4a6b --- /dev/null +++ b/invokeai/backend/model_manager/storage/__init__.py @@ -0,0 +1,13 @@ +"""Initialization file for invokeai.backend.model_manager.storage.""" +import pathlib + +from ..config import AnyModelConfig # noqa F401 +from .base import ( # noqa F401 + ConfigFileVersionMismatchException, + DuplicateModelException, + ModelConfigStore, + UnknownModelException, +) +from .migrate import migrate_models_store # noqa F401 +from .sql import ModelConfigStoreSQL # noqa F401 +from .yaml import ModelConfigStoreYAML # noqa F401 diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py new file mode 100644 index 00000000000..faa7ecb9b04 --- /dev/null +++ b/invokeai/backend/model_manager/storage/base.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Abstract base class for storing and retrieving model configuration records. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Optional, Set, Union + +from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType + +# should match the InvokeAI version when this is first released. +CONFIG_FILE_VERSION = "3.2" + + +class DuplicateModelException(Exception): + """Raised on an attempt to add a model with the same key twice.""" + + +class InvalidModelException(Exception): + """Raised when an invalid model is detected.""" + + +class UnknownModelException(Exception): + """Raised on an attempt to fetch or delete a model with a nonexistent key.""" + + +class ConfigFileVersionMismatchException(Exception): + """Raised on an attempt to open a config with an incompatible version.""" + + +class ModelConfigStore(ABC): + """Abstract base class for storage and retrieval of model configs.""" + + @property + @abstractmethod + def version(self) -> str: + """Return the config file/database schema version.""" + pass + + @abstractmethod + def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase: + """ + Add a model to the database. + + :param key: Unique key for the model + :param config: Model configuration record, either a dict with the + required fields or a ModelConfigBase instance. + + Can raise DuplicateModelException and InvalidModelConfigException exceptions. + """ + pass + + @abstractmethod + def del_model(self, key: str) -> None: + """ + Delete a model. 
+ + :param key: Unique key for the model to be deleted + + Can raise an UnknownModelException + """ + pass + + @abstractmethod + def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: + """ + Update the model, returning the updated version. + + :param key: Unique key for the model to be updated + :param config: Model configuration record. Either a dict with the + required fields, or a ModelConfigBase instance. + """ + pass + + @abstractmethod + def get_model(self, key: str) -> AnyModelConfig: + """ + Retrieve the configuration for the indicated model. + + :param key: Key of model config to be fetched. + + Exceptions: UnknownModelException + """ + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """ + Return True if a model with the indicated key exists in the databse. + + :param key: Unique key for the model to be deleted + """ + pass + + @abstractmethod + def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]: + """ + Return models containing all of the listed tags. + + :param tags: Set of tags to search on. + """ + pass + + @abstractmethod + def search_by_path( + self, + path: Union[str, Path], + ) -> Optional[AnyModelConfig]: + """Return the model having the indicated path.""" + pass + + @abstractmethod + def search_by_name( + self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[AnyModelConfig]: + """ + Return models matching name, base and/or type. + + :param model_name: Filter by name of model (optional) + :param base_model: Filter by base model (optional) + :param model_type: Filter by type of model (optional) + + If none of the optional filters are passed, will return all + models in the database. + """ + pass + + def all_models(self) -> List[AnyModelConfig]: + """Return all the model configs in the database.""" + return self.search_by_name() + + def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase: + """ + Return information about a single model using its name, base type and model type. + + If there are more than one model that match, raises a DuplicateModelException. + If no model matches, raises an UnknownModelException + """ + model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type) + if len(model_configs) > 1: + raise DuplicateModelException( + "More than one model share the same name and type: {base_model}/{model_type}/{model_name}" + ) + if len(model_configs) == 0: + raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}") + return model_configs[0] + + def rename_model( + self, + key: str, + new_name: str, + ) -> ModelConfigBase: + """ + Rename the indicated model. Just a special case of update_model(). + + In some implementations, renaming the model may involve changing where + it is stored on the filesystem. So this is broken out. 
+ + :param key: Model key + :param new_name: New name for model + """ + return self.update_model(key, {"name": new_name}) diff --git a/invokeai/backend/model_manager/storage/migrate.py b/invokeai/backend/model_manager/storage/migrate.py new file mode 100644 index 00000000000..d17efd66e33 --- /dev/null +++ b/invokeai/backend/model_manager/storage/migrate.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 The InvokeAI Development Team + +import shutil +from pathlib import Path + +from omegaconf import OmegaConf + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger + +from ..config import BaseModelType, MainCheckpointConfig, MainConfig, ModelType +from .base import CONFIG_FILE_VERSION + + +def migrate_models_store(config: InvokeAIAppConfig) -> Path: + """Migrate models from v1 models.yaml to v3.2 models.yaml.""" + # avoid circular import + from invokeai.backend.model_manager.install import DuplicateModelException, ModelInstall + from invokeai.backend.model_manager.storage import get_config_store + + app_config = InvokeAIAppConfig.get_config() + logger = InvokeAILogger.get_logger() + old_file: Path = app_config.model_conf_path + new_file: Path = old_file.with_name("models3_2.yaml") + + old_conf = OmegaConf.load(old_file) + store = get_config_store(new_file) + installer = ModelInstall(store=store) + logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format") + + for model_key, stanza in old_conf.items(): + if model_key == "__metadata__": + assert ( + stanza["version"] == "3.0.0" + ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version" + continue + + base_type, model_type, model_name = str(model_key).split("/") + new_key = "" + + try: + path = app_config.models_path / stanza["path"] + new_key = installer.register_path(path) + except DuplicateModelException: + # if model already installed, then we just update its info + models = store.search_by_name( + model_name=model_name, base_model=BaseModelType(base_type), model_type=ModelType(model_type) + ) + if len(models) != 1: + continue + new_key = models[0].key + except Exception as excp: + print(str(excp)) + + if new_key != "": + model_info = store.get_model(new_key) + if (vae := stanza.get("vae")) and isinstance(model_info, MainConfig): + model_info.vae = (app_config.models_path / vae).as_posix() + if (model_config := stanza.get("config")) and isinstance(model_info, MainCheckpointConfig): + model_info.config = (app_config.root_path / model_config).as_posix() + model_info.description = stanza.get("description") + store.update_model(new_key, model_info) + + logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}") + shutil.move(old_file, str(old_file) + ".orig") + shutil.move(new_file, old_file) + return old_file diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py new file mode 100644 index 00000000000..50f4db49abc --- /dev/null +++ b/invokeai/backend/model_manager/storage/sql.py @@ -0,0 +1,468 @@ +# Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Development Team +""" +Implementation of ModelConfigStore using a SQLite3 database + +Typical usage: + + from invokeai.backend.model_manager import ModelConfigStoreSQL + store = ModelConfigStoreYAML("./configs/models.yaml") + config = dict( + path='/tmp/pokemon.bin', + name='old name', + base_model='sd-1', + model_type='embedding', + model_format='embedding_file', + author='Anonymous', + tags=['sfw','cartoon'] + ) + + # adding - the key becomes the model's "key" field + store.add_model('key1', config) + + # updating + config.name='new name' + store.update_model('key1', config) + + # checking for existence + if store.exists('key1'): + print("yes") + + # fetching config + new_config = store.get_model('key1') + print(new_config.name, new_config.base_model) + assert new_config.key == 'key1' + + # deleting + store.del_model('key1') + + # searching + configs = store.search_by_tag({'sfw','oss license'}) + configs = store.search_by_name(base_model='sd-2', model_type='main') +""" + +import json +import sqlite3 +import threading +from pathlib import Path +from typing import List, Optional, Set, Union + +from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType +from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException + + +class ModelConfigStoreSQL(ModelConfigStore): + """Implementation of the ModelConfigStore ABC using a YAML file.""" + + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, conn: sqlite3.Connection, lock: threading.Lock): + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. + + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + + super().__init__() + self._conn = conn + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) 
+ self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = lock + + with self._lock: + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + assert ( + str(self.version) == CONFIG_FILE_VERSION + ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" + + def _create_tables(self) -> None: + """Create sqlite3 tables.""" + # model_config table breaks out the fields that are common to all config objects + # and puts class-specific ones in a serialized json object + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS model_config ( + id TEXT NOT NULL PRIMARY KEY, + -- These 4 fields are enums in python, unrestricted string here + base_model TEXT NOT NULL, + model_type TEXT NOT NULL, + model_name TEXT NOT NULL, + model_path TEXT NOT NULL, + -- Serialized JSON representation of the whole config object, + -- which will contain additional fields from subclasses + config TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + ); + """ + ) + + # model_tag table 1:M relation between model key and tag(s) + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS model_tag ( + id TEXT NOT NULL, + tag_id INTEGER NOT NULL, + FOREIGN KEY(id) REFERENCES model_config(id), + FOREIGN KEY(tag_id) REFERENCES tags(tag_id), + UNIQUE(id,tag_id) + ); + """ + ) + + # tags table + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS tags ( + tag_id INTEGER NOT NULL PRIMARY KEY, + tag_text TEXT NOT NULL UNIQUE + ); + """ + ) + + # metadata table + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS model_manager_metadata ( + metadata_key TEXT NOT NULL PRIMARY KEY, + metadata_value TEXT NOT NULL + ); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_updated_at + AFTER UPDATE + ON model_config FOR EACH ROW + BEGIN + UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + + # Add trigger to remove tags when model is deleted + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_deleted + AFTER DELETE + ON model_config + BEGIN + DELETE from model_tag WHERE id=old.id; + END; + """ + ) + + # Add our version to the metadata table + self._cursor.execute( + """--sql + INSERT OR IGNORE into model_manager_metadata ( + metadata_key, + metadata_value + ) + VALUES (?,?); + """, + ("version", CONFIG_FILE_VERSION), + ) + + def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: + """ + Add a model to the database. + + :param key: Unique key for the model + :param config: Model configuration record, either a dict with the + required fields or a ModelConfigBase instance. + + Can raise DuplicateModelException and InvalidModelConfigException exceptions. + """ + record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. + json_serialized = json.dumps(record.dict()) # and turn it into a json string. 
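+ # hold the lock for the whole insert so the config row, its tags, and the commit
+ # are applied together; any sqlite3 error below triggers a rollback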
+ with self._lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_config ( + id, + base_model, + model_type, + model_name, + model_path, + config + ) + VALUES (?,?,?,?,?,?); + """, + ( + key, + record.base_model, + record.model_type, + record.name, + record.path, + json_serialized, + ), + ) + if record.tags: + self._update_tags(key, record.tags) + self._conn.commit() + + except sqlite3.IntegrityError as e: + self._conn.rollback() + if "UNIQUE constraint failed" in str(e): + raise DuplicateModelException(f"A model with key '{key}' is already installed") from e + else: + raise e + except sqlite3.Error as e: + self._conn.rollback() + raise e + + return self.get_model(key) + + @property + def version(self) -> str: + """Return the version of the database schema.""" + with self._lock: + self._cursor.execute( + """--sql + SELECT metadata_value FROM model_manager_metadata + WHERE metadata_key=?; + """, + ("version",), + ) + rows = self._cursor.fetchone() + if not rows: + raise KeyError("Models database does not have metadata key 'version'") + return rows[0] + + def _update_tags(self, key: str, tags: List[str]) -> None: + """Update tags for model with key.""" + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tag + WHERE id=?; + """, + (key,), + ) + + # NOTE: isn't there a more elegant way of doing this than one tag + # at a time, with a select to get the tag ID? + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? + LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tag ( + id, + tag_id + ) + VALUES (?,?); + """, + (key, tag_id), + ) + + def del_model(self, key: str) -> None: + """ + Delete a model. + + :param key: Unique key for the model to be deleted + + Can raise an UnknownModelException + """ + with self._lock: + try: + self._cursor.execute( + """--sql + DELETE FROM model_config + WHERE id=?; + """, + (key,), + ) + if self._cursor.rowcount == 0: + raise UnknownModelException + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + + def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: + """ + Update the model, returning the updated version. + + :param key: Unique key for the model to be updated + :param config: Model configuration record. Either a dict with the + required fields, or a ModelConfigBase instance. + """ + record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect + json_serialized = json.dumps(record.dict()) # and turn it into a json string. + with self._lock: + try: + self._cursor.execute( + """--sql + UPDATE model_config + SET base_model=?, + model_type=?, + model_name=?, + model_path=?, + config=? + WHERE id=?; + """, + (record.base_model, record.model_type, record.name, record.path, json_serialized, key), + ) + if self._cursor.rowcount == 0: + raise UnknownModelException + if record.tags: + self._update_tags(key, record.tags) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + + return self.get_model(key) + + def get_model(self, key: str) -> AnyModelConfig: + """ + Retrieve the ModelConfigBase instance for the indicated model. + + :param key: Key of model config to be fetched. 
+ + Exceptions: UnknownModelException + """ + with self._lock: + self._cursor.execute( + """--sql + SELECT config FROM model_config + WHERE id=?; + """, + (key,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownModelException + model = ModelConfigFactory.make_config(json.loads(rows[0])) + return model + + def exists(self, key: str) -> bool: + """ + Return True if a model with the indicated key exists in the databse. + + :param key: Unique key for the model to be deleted + """ + count = 0 + with self._lock: + try: + self._cursor.execute( + """--sql + select count(*) FROM model_config + WHERE id=?; + """, + (key,), + ) + count = self._cursor.fetchone()[0] + except sqlite3.Error as e: + raise e + return count > 0 + + def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]: + """Return models containing all of the listed tags.""" + # rather than create a hairy SQL cross-product, we intersect + # tag results in a stepwise fashion at the python level. + results = [] + with self._lock: + try: + matches: Set[str] = set() + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.id FROM model_tag AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys + if matches: + self._cursor.execute( + f"""--sql + SELECT config FROM model_config + WHERE id IN ({','.join('?' * len(matches))}); + """, + tuple(matches), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + except sqlite3.Error as e: + raise e + return results + + def search_by_name( + self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[AnyModelConfig]: + """ + Return models matching name, base and/or type. + + :param model_name: Filter by name of model (optional) + :param base_model: Filter by base model (optional) + :param model_type: Filter by type of model (optional) + + If none of the optional filters are passed, will return all + models in the database. + """ + results = [] + where_clause = [] + bindings = [] + if model_name: + where_clause.append("model_name=?") + bindings.append(model_name) + if base_model: + where_clause.append("base_model=?") + bindings.append(base_model) + if model_type: + where_clause.append("model_type=?") + bindings.append(model_type) + where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" + with self._lock: + try: + self._cursor.execute( + f"""--sql + select config FROM model_config + {where}; + """, + tuple(bindings), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + except sqlite3.Error as e: + raise e + return results + + def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: + """Return the model with the indicated path, or None.""" + raise NotImplementedError("search_by_path not implemented in storage.sql") diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py new file mode 100644 index 00000000000..f8b7f51c86c --- /dev/null +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Implementation of ModelConfigStore using a YAML file. 
+
+Typical usage:
+
+   from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML
+   store = ModelConfigStoreYAML("./configs/models.yaml")
+   config = dict(
+        path='/tmp/pokemon.bin',
+        name='old name',
+        base_model='sd-1',
+        model_type='embedding',
+        model_format='embedding_file',
+        author='Anonymous',
+        tags=['sfw','cartoon']
+      )
+
+   # adding - the key becomes the model's "key" field
+   store.add_model('key1', config)
+
+   # updating
+   config['name'] = 'new name'
+   store.update_model('key1', config)
+
+   # checking for existence
+   if store.exists('key1'):
+      print("yes")
+
+   # fetching config
+   new_config = store.get_model('key1')
+   print(new_config.name, new_config.base_model)
+   assert new_config.key == 'key1'
+
+   # deleting
+   store.del_model('key1')
+
+   # searching
+   configs = store.search_by_tag({'sfw','oss license'})
+   configs = store.search_by_name(base_model='sd-2', model_type='main')
+"""
+
+import threading
+from enum import Enum
+from pathlib import Path
+from typing import List, Optional, Set, Union
+
+import yaml
+from omegaconf import OmegaConf
+from omegaconf.dictconfig import DictConfig
+
+from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
+from .base import (
+    CONFIG_FILE_VERSION,
+    ConfigFileVersionMismatchException,
+    DuplicateModelException,
+    ModelConfigStore,
+    UnknownModelException,
+)
+
+
+class ModelConfigStoreYAML(ModelConfigStore):
+    """Implementation of the ModelConfigStore ABC using a YAML file."""
+
+    _filename: Path
+    _config: DictConfig
+    _lock: threading.RLock
+
+    def __init__(self, config_file: Path):
+        """Initialize ModelConfigStore object with a .yaml file."""
+        super().__init__()
+        self._filename = Path(config_file).absolute()  # don't let chdir mess us up!
+        self._lock = threading.RLock()
+        if not self._filename.exists():
+            self._initialize_yaml()
+        config = OmegaConf.load(self._filename)
+        assert isinstance(config, DictConfig)
+        self._config = config
+        if str(self.version) != CONFIG_FILE_VERSION:
+            raise ConfigFileVersionMismatchException
+
+    def _initialize_yaml(self):
+        with self._lock:
+            self._filename.parent.mkdir(parents=True, exist_ok=True)
+            with open(self._filename, "w") as yaml_file:
+                yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
+
+    def _commit(self):
+        with self._lock:
+            newfile = Path(str(self._filename) + ".new")
+            yaml_str = OmegaConf.to_yaml(self._config)
+            with open(newfile, "w", encoding="utf-8") as outfile:
+                outfile.write(yaml_str)
+            newfile.replace(self._filename)
+
+    @property
+    def version(self) -> str:
+        """Return version of this config file/database."""
+        return self._config.__metadata__.get("version")
+
+    def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
+        """
+        Add a model to the database.
+
+        :param key: Unique key for the model
+        :param config: Model configuration record, either a dict with the
+         required fields or a ModelConfigBase instance.
+
+        Can raise DuplicateModelException and InvalidModelConfigException exceptions.
+ """ + record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect + dict_fields = record.dict() # and back to a dict with valid fields + with self._lock: + if key in self._config: + existing_model = self.get_model(key) + raise DuplicateModelException( + f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'" + ) + self._config[key] = self._fix_enums(dict_fields) + self._commit() + return self.get_model(key) + + def _fix_enums(self, original: dict) -> dict: + """In python 3.9, omegaconf stores incorrectly stringified enums.""" + fixed_dict = {} + for key, value in original.items(): + fixed_dict[key] = value.value if isinstance(value, Enum) else value + return fixed_dict + + def del_model(self, key: str) -> None: + """ + Delete a model. + + :param key: Unique key for the model to be deleted + + Can raise an UnknownModelException + """ + with self._lock: + if key not in self._config: + raise UnknownModelException(f"Unknown key '{key}' for model config") + self._config.pop(key) + self._commit() + + def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: + """ + Update the model, returning the updated version. + + :param key: Unique key for the model to be updated + :param config: Model configuration record. Either a dict with the + required fields, or a ModelConfigBase instance. + """ + record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect + dict_fields = record.dict() # and back to a dict with valid fields + with self._lock: + if key not in self._config: + raise UnknownModelException(f"Unknown key '{key}' for model config") + self._config[key] = self._fix_enums(dict_fields) + self._commit() + return self.get_model(key) + + def get_model(self, key: str) -> AnyModelConfig: + """ + Retrieve the ModelConfigBase instance for the indicated model. + + :param key: Key of model config to be fetched. + + Exceptions: UnknownModelException + """ + try: + record = self._config[key] + return ModelConfigFactory.make_config(record, key) + except KeyError as e: + raise UnknownModelException(f"Unknown key '{key}' for model config") from e + + def exists(self, key: str) -> bool: + """ + Return True if a model with the indicated key exists in the databse. + + :param key: Unique key for the model to be deleted + """ + return key in self._config + + def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]: + """ + Return models containing all of the listed tags. + + :param tags: Set of tags to search on. + """ + results = [] + tags = set(tags) + with self._lock: + for config in self.all_models(): + config_tags = set(config.tags or []) + if tags.difference(config_tags): # not all tags in the model + continue + results.append(config) + return results + + def search_by_name( + self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[ModelConfigBase]: + """ + Return models matching name, base and/or type. + + :param model_name: Filter by name of model (optional) + :param base_model: Filter by base model (optional) + :param model_type: Filter by type of model (optional) + + If none of the optional filters are passed, will return all + models in the database. 
+ """ + results: List[ModelConfigBase] = list() + with self._lock: + for key, record in self._config.items(): + if key == "__metadata__": + continue + model = ModelConfigFactory.make_config(record, str(key)) + if model_name and model.name != model_name: + continue + if base_model and model.base_model != base_model: + continue + if model_type and model.model_type != model_type: + continue + results.append(model) + return results + + def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: + """Return the model with the indicated path, or None.""" + with self._lock: + for key, record in self._config.items(): + if key == "__metadata__": + continue + model = ModelConfigFactory.make_config(record, str(key)) + if model.path == path: + return model + return None diff --git a/invokeai/backend/model_manager/util.py b/invokeai/backend/model_manager/util.py new file mode 100644 index 00000000000..34ad6e92d97 --- /dev/null +++ b/invokeai/backend/model_manager/util.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team +""" +Various utilities used by the model manager. +""" +import json +import warnings +from pathlib import Path +from typing import Optional, Union + +import safetensors +import torch +from diffusers import logging as diffusers_logging +from picklescan.scanner import scan_file_path +from transformers import logging as transformers_logging + + +class SilenceWarnings(object): + """ + Context manager that silences warnings from transformers and diffusers. + + Usage: + with SilenceWarnings(): + do_something_that_generates_warnings() + """ + + def __init__(self): + """Initialize SilenceWarnings context.""" + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self): + """Entry into the context.""" + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, type, value, traceback): + """Exit from the context.""" + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") + + +def lora_token_vector_length(checkpoint: dict) -> Optional[int]: + """ + Given a checkpoint in memory, return the lora token vector length. + + :param checkpoint: The checkpoint + """ + + def _get_shape_1(key, tensor, checkpoint): + lora_token_vector_length = None + + if "." 
not in key:
+            return lora_token_vector_length  # wrong key format
+        model_key, lora_key = key.split(".", 1)
+
+        # check lora/locon
+        if lora_key == "lora_down.weight":
+            lora_token_vector_length = tensor.shape[1]
+
+        # check loha (don't worry about hada_t1/hada_t2 as it is used only in 4d shapes)
+        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check lokr (don't worry about lokr_t2 as it is used only in 4d shapes)
+        elif "lokr_" in lora_key:
+            if model_key + ".lokr_w1" in checkpoint:
+                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
+            elif model_key + ".lokr_w1_b" in checkpoint:
+                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
+            else:
+                return lora_token_vector_length  # unknown format
+
+            if model_key + ".lokr_w2" in checkpoint:
+                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
+            elif model_key + ".lokr_w2_b" in checkpoint:
+                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
+            else:
+                return lora_token_vector_length  # unknown format
+
+            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
+
+        elif lora_key == "diff":
+            lora_token_vector_length = tensor.shape[1]
+
+        # ia3 can be detected only by shape[0] in text encoder
+        elif lora_key == "weight" and "lora_unet_" not in model_key:
+            lora_token_vector_length = tensor.shape[0]
+
+        return lora_token_vector_length
+
+    lora_token_vector_length = None
+    lora_te1_length = None
+    lora_te2_length = None
+    for key, tensor in checkpoint.items():
+        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
+            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
+        elif key.startswith("lora_te") and "_self_attn_" in key:
+            tmp_length = _get_shape_1(key, tensor, checkpoint)
+            if key.startswith("lora_te_"):
+                lora_token_vector_length = tmp_length
+            elif key.startswith("lora_te1_"):
+                lora_te1_length = tmp_length
+            elif key.startswith("lora_te2_"):
+                lora_te2_length = tmp_length
+
+        if lora_te1_length is not None and lora_te2_length is not None:
+            lora_token_vector_length = lora_te1_length + lora_te2_length
+
+        if lora_token_vector_length is not None:
+            break
+
+    return lora_token_vector_length
+
+
+def _fast_safetensors_reader(path: str):
+    checkpoint = dict()
+    device = torch.device("meta")
+    with open(path, "rb") as f:
+        definition_len = int.from_bytes(f.read(8), "little")
+        definition_json = f.read(definition_len)
+        definition = json.loads(definition_json)
+
+        if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
+            "pt",
+            "torch",
+            "pytorch",
+        }:
+            raise Exception("Only PyTorch safetensors files are supported")
+        definition.pop("__metadata__", None)
+
+        for key, info in definition.items():
+            dtype = {
+                "I8": torch.int8,
+                "I16": torch.int16,
+                "I32": torch.int32,
+                "I64": torch.int64,
+                "F16": torch.float16,
+                "F32": torch.float32,
+                "F64": torch.float64,
+            }[info["dtype"]]
+
+            checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
+
+    return checkpoint
+
+
+def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
+    if str(path).endswith(".safetensors"):
+        try:
+            checkpoint = _fast_safetensors_reader(str(path))
+        except Exception:
+            # TODO: create issue for support "meta"?
+            checkpoint = safetensors.torch.load_file(path, device="cpu")
+    else:
+        if scan:
+            scan_result = scan_file_path(path)
+            if scan_result.infected_files != 0:
+                raise Exception(f'The model file "{path}" is potentially infected by malware.
Aborting import.') + checkpoint = torch.load(path, map_location=torch.device("meta")) + return checkpoint diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index 153bd0fcc4b..5cd77119d86 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -11,6 +11,7 @@ import math import os import random +import re from pathlib import Path from typing import Optional @@ -41,8 +42,8 @@ # invokeai stuff from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser -from invokeai.app.services.model_manager_service import ModelManagerService -from invokeai.backend.model_management.models import SubModelType +from invokeai.app.services.model_manager_service import BaseModelType, ModelManagerService, ModelType +from invokeai.backend.model_manager import SubModelType if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): PIL_INTERPOLATION = { @@ -66,7 +67,6 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.10.0.dev0") - logger = get_logger(__name__) @@ -114,7 +114,6 @@ def parse_args(): general_group.add_argument( "--output_dir", type=Path, - default=f"{config.root}/text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) model_group.add_argument( @@ -550,8 +549,11 @@ def do_textual_inversion_training( local_rank = env_local_rank # setting up things the way invokeai expects them + output_dir = output_dir or config.root_path / "text-inversion-output" + + print(f"output_dir={output_dir}") if not os.path.isabs(output_dir): - output_dir = os.path.join(config.root, output_dir) + output_dir = Path(config.root, output_dir) logging_dir = output_dir / logging_dir @@ -564,14 +566,15 @@ def do_textual_inversion_training( project_config=accelerator_config, ) - model_manager = ModelManagerService(config, logger) + model_manager = ModelManagerService(config) + # The InvokeAI logger already does this... # Make one log on every process with the configuration for debugging. 
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-    )
+    # logging.basicConfig(
+    #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    #     datefmt="%m/%d/%Y %H:%M:%S",
+    #     level=logging.INFO,
+    # )
     logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
@@ -603,17 +606,30 @@ def do_textual_inversion_training(
     elif output_dir is not None:
         os.makedirs(output_dir, exist_ok=True)
 
-    known_models = model_manager.model_names()
-    model_name = model.split("/")[-1]
-    model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
-    assert model_meta is not None, f"Unknown model: {model}"
-    model_info = model_manager.model_info(*model_meta)
-    assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
-    tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
-    noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
-    text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
-    vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
-    unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
+    if len(model) == 32 and re.match(r"^[0-9a-f]+$", model):  # looks like a key, not a model name
+        model_key = model
+    else:
+        parts = model.split("/")
+        if len(parts) == 3:
+            base_model, model_type, model_name = parts
+        else:
+            model_name = parts[-1]
+            base_model = BaseModelType("sd-1")
+            model_type = ModelType.Main
+        models = model_manager.list_models(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+        )
+        assert len(models) > 0, f"Unknown model: {model}"
+        assert len(models) < 2, f"More than one model named {model_name}. Please pass key instead."
+ model_key = models[0].key + + tokenizer_info = model_manager.get_model(model_key, submodel_type=SubModelType.Tokenizer) + noise_scheduler_info = model_manager.get_model(model_key, submodel_type=SubModelType.Scheduler) + text_encoder_info = model_manager.get_model(model_key, submodel_type=SubModelType.TextEncoder) + vae_info = model_manager.get_model(model_key, submodel_type=SubModelType.Vae) + unet_info = model_manager.get_model(model_key, submodel_type=SubModelType.UNet) pipeline_args = dict(local_files_only=True) if tokenizer_name: diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 601aab00cbb..186d842723a 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -1,6 +1,8 @@ """ Initialization file for invokeai.backend.util """ +from logging import Logger # noqa: F401 + from .attention import auto_detect_slice_size # noqa: F401 from .devices import ( # noqa: F401 CPU_DEVICE, @@ -11,4 +13,13 @@ normalize_device, torch_dtype, ) -from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 +from .logging import InvokeAILogger # noqa: F401 +from .util import ( # noqa: F401 + GIG, + Chdir, + ask_user, + directory_size, + download_with_resume, + instantiate_from_config, + url_attachment_name, +) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 84ca7ee02bf..703d6b6c835 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -2,7 +2,7 @@ import platform from contextlib import nullcontext -from typing import Union +from typing import Literal, Union import torch from packaging import version @@ -42,6 +42,13 @@ def choose_precision(device: torch.device) -> str: return "float32" +def get_precision() -> Literal["float16", "float32"]: + device = torch.device(choose_torch_device()) + precision = choose_precision(device) if config.precision == "auto" else config.precision + assert precision in ["float16", "float32"] + return precision + + def torch_dtype(device: torch.device) -> torch.dtype: if config.full_precision: return torch.float32 diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 3c829a1a02e..8d763f81128 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -180,6 +180,7 @@ import urllib.parse from abc import abstractmethod from pathlib import Path +from typing import Dict from invokeai.app.services.config import InvokeAIAppConfig @@ -293,7 +294,7 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter): } def log_fmt(self, levelno: int) -> str: - return self.FORMATS.get(levelno) + return self.FORMATS[levelno] class InvokeAIPlainLogFormatter(InvokeAIFormatter): @@ -332,7 +333,7 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter): } def log_fmt(self, levelno: int) -> str: - return self.FORMATS.get(levelno) + return self.FORMATS[levelno] LOG_FORMATTERS = { @@ -344,17 +345,19 @@ def log_fmt(self, levelno: int) -> str: class InvokeAILogger(object): - loggers = dict() + loggers: Dict[str, logging.Logger] = dict() @classmethod def get_logger( cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() ) -> logging.Logger: + """Return a logger appropriately configured for the current InvokeAI configuration.""" if name in cls.loggers: logger = cls.loggers[name] logger.handlers.clear() else: logger = logging.getLogger(name) + config = config or InvokeAIAppConfig.get_config() # in case None is passed 
logger.setLevel(config.log_level.upper()) # yes, strings work here for ch in cls.get_loggers(config): logger.addHandler(ch) diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 1c7b5388824..0d2690ee985 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -6,9 +6,10 @@ import torch from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import ModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType +from invokeai.app.services.model_install_service import ModelInstallService +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException +from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad @pytest.fixture(scope="session") @@ -24,11 +25,16 @@ def model_installer(): # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround, # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using # a singleton. - return ModelInstall(InvokeAIAppConfig.get_config(log_level="info")) + # + # REPLY(lstein): Don't use get_config() here. Just use the regular pydantic constructor. + # + config = InvokeAIAppConfig(log_level="info") + model_store = ModelRecordServiceBase.open(config) + return ModelInstallService(store=model_store, config=config) def install_and_load_model( - model_installer: ModelInstall, + model_installer: ModelInstallService, model_path_id_or_url: Union[str, Path], model_name: str, base_model: BaseModelType, @@ -52,15 +58,19 @@ def install_and_load_model( ModelInfo """ # If the requested model is already installed, return its ModelInfo. - with contextlib.suppress(ModelNotFoundException): - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) + loader = ModelLoad(config=model_installer.config, store=model_installer.store) + with contextlib.suppress(UnknownModelException): + model = model_installer.store.model_info_by_name(model_name, base_model, model_type) + return loader.get_model(model.key, submodel_type) # Install the requested model. - model_installer.heuristic_import(model_path_id_or_url) + model_installer.install(model_path_id_or_url) + model_installer.wait_for_installs() try: - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) - except ModelNotFoundException as e: + model = model_installer.store.model_info_by_name(model_name, base_model, model_type) + return loader.get_model(model.key, submodel_type) + except UnknownModelException as e: raise Exception( "Failed to get model info after installing it. There could be a mismatch between the requested model and" f" the installation id ('{model_path_id_or_url}'). 
Error: {e}" diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 0796f1a8cdb..78d7410fc52 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -2,14 +2,11 @@ import importlib import io import math -import multiprocessing as mp import os import re -from collections import abc from inspect import isfunction from pathlib import Path -from queue import Queue -from threading import Thread +from typing import Optional import numpy as np import requests @@ -21,6 +18,9 @@ from .devices import torch_dtype +# actual size of a gig +GIG = 1073741824 + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) @@ -101,112 +101,6 @@ def get_obj_from_str(string, reload=False): return getattr(importlib.import_module(module, package=None), cls) -def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): - # create dummy dataset instance - - # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) - Q.put([idx, res]) - Q.put("Done") - - -def parallel_data_prefetch( - func: callable, - data, - n_proc, - target_data_type="ndarray", - cpu_intensive=True, - use_worker_id=False, -): - # if target_data_type not in ["ndarray", "list"]: - # raise ValueError( - # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." - # ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") - elif isinstance(data, abc.Iterable): - if isinstance(data, dict): - logger.warning( - '"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' - ) - data = list(data.values()) - if target_data_type == "ndarray": - data = np.asarray(data) - else: - data = list(data) - else: - raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." - ) - - if cpu_intensive: - Q = mp.Queue(1000) - proc = mp.Process - else: - Q = Queue(1000) - proc = Thread - # spawn processes - if target_data_type == "ndarray": - arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))] - else: - step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc) - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)]) - ] - processes = [] - for i in range(n_proc): - p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) - processes += [p] - - # start processes - logger.info("Start prefetching...") - import time - - start = time.time() - gather_res = [[] for _ in range(n_proc)] - try: - for p in processes: - p.start() - - k = 0 - while k < n_proc: - # get result - res = Q.get() - if res == "Done": - k += 1 - else: - gather_res[res[0]] = res[1] - - except Exception as e: - logger.error("Exception: ", e) - for p in processes: - p.terminate() - - raise e - finally: - for p in processes: - p.join() - logger.info(f"Prefetching complete. 
[{time.time() - start} sec.]") - - if target_data_type == "ndarray": - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) - - # order outputs - return np.concatenate(gather_res, axis=0) - elif target_data_type == "list": - out = [] - for r in gather_res: - out.extend(r) - return out - else: - return gather_res - - def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) @@ -269,7 +163,7 @@ def ask_user(question: str, answers: list): # ------------------------------------- -def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path: +def download_with_resume(url: str, dest: Path, access_token: str = None) -> Optional[Path]: """ Download a model file. :param url: https, http or ftp URL @@ -286,10 +180,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path content_length = int(resp.headers.get("content-length", 0)) if dest.is_dir(): - try: - file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1) - except AttributeError: - file_name = os.path.basename(url) + file_name = response_attachment(resp) or os.path.basename(url) dest = dest / file_name else: dest.parent.mkdir(parents=True, exist_ok=True) @@ -338,15 +229,24 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path return dest -def url_attachment_name(url: str) -> dict: +def response_attachment(response: requests.Response) -> Optional[str]: try: - resp = requests.get(url, stream=True) - match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")) - return match.group(1) + if disposition := response.headers.get("Content-Disposition"): + if match := re.search('filename="(.+)"', disposition): + return match.group(1) + return None except Exception: return None +def url_attachment_name(url: str) -> Optional[str]: + resp = requests.get(url) + if resp.ok: + return response_attachment(resp) + else: + return None + + def download_with_progress_bar(url: str, dest: Path) -> bool: result = download_with_resume(url, dest, access_token=None) return result is not None @@ -363,6 +263,19 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str: return image_base64 +def directory_size(directory: Path) -> int: + """ + Returns the aggregate size of all files in a directory (bytes). + """ + sum = 0 + for root, dirs, files in os.walk(directory): + for f in files: + sum += Path(root, f).stat().st_size + for d in dirs: + sum += Path(root, d).stat().st_size + return sum + + class Chdir(object): """Context manager to chdir to desired directory and change back after context exits: Args: diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index b6883ea9151..9f191504bc5 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -1,156 +1,157 @@ # This file predefines a few models that the user may want to install. 
sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 + source: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting + source: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 + source: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting + source: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 + source: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 + source: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-vae-fp16-fix: + description: Version of the SDXL-1.0 VAE that works in half precision mode + source: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion + source: wavymulder/Analog-Diffusion recommended: False sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - repo_id: XpucT/Deliberate + source: XpucT/Deliberate recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion + source: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 + source: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion + source: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney + source: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA + source: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 + source: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster + source: monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - repo_id: 
lllyasviel/control_v11p_sd15_canny + source: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint + source: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd + source: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth + source: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae + source: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg + source: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart + source: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime + source: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - repo_id: lllyasviel/control_v11p_sd15_openpose + source: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble + source: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge + source: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle + source: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile + source: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p + source: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 + source: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 + source: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - repo_id: TencentARC/t2iadapter_depth_sd15v2 + source: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 + source: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 + source: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 + source: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 + source: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d + description: A textual inversion to use in the negative prompt to reduce bad anatomy sd-1/lora/LowRA: - path: https://civitai.com/api/download/models/63006 + source: https://civitai.com/api/download/models/63006 recommended: True + description: An embedding that helps generate low-light images sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 + source: https://civitai.com/api/download/models/83390 + description: Generate india ink-like landscapes 
sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 + source: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 + source: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 + source: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl + source: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder + source: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder + source: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/stable-diffusion/v1-inpainting-inference-v.yaml b/invokeai/configs/stable-diffusion/v1-inpainting-inference-v.yaml new file mode 100644 index 00000000000..2399f62afd9 --- /dev/null +++ b/invokeai/configs/stable-diffusion/v1-inpainting-inference-v.yaml @@ -0,0 +1,80 @@ +model: + base_learning_rate: 7.5e-05 + target: invokeai.backend.models.diffusion.ddpm.LatentInpaintDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + personalization_config: + target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ['sculpture'] + per_image_tokens: false + num_vectors_per_token: 8 + progressive_words: False + + unet_config: + target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 1fb3b618917..042336303a0 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -6,28 +6,29 @@ """ This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. """ import argparse import curses -import logging import sys -import textwrap import traceback from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe +from dataclasses import dataclass, field from pathlib import Path from shutil import get_terminal_size +from typing import Dict, List, Optional, Tuple import npyscreen +import omegaconf import torch from npyscreen import widget +from pydantic import BaseModel +import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType +from invokeai.app.services.model_install_service import ModelInstallJob, ModelInstallService +from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo +from invokeai.backend.model_manager import BaseModelType, ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( @@ -40,7 +41,6 @@ SingleSelectColumns, TextBox, WindowTooSmallException, - select_stable_diffusion_config_file, set_min_terminal_size, ) @@ -56,12 +56,20 @@ MAX_OTHER_MODELS = 72 +@dataclass +class InstallSelections: + install_models: List[UnifiedModelInfo] = field(default_factory=list) + remove_models: List[str] = field(default_factory=list) + + def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" + """Replace non-printable characters in a string.""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): + """Main form for interactive TUI.""" + # for responsive resizing set to False, but this seems to cause a crash! 
FIX_MINIMUM_SIZE_WHEN_CREATED = True @@ -74,17 +82,12 @@ def __init__(self, parentApp, name, multipage=False, *args, **keywords): super().__init__(parentApp=parentApp, name=name, *args, **keywords) def create(self): + self.installer = self.parentApp.install_helper.installer + self.model_labels = self._get_model_labels() self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() self.nextrely -= 1 @@ -161,15 +164,7 @@ def create(self): self.nextrely = bottom_of_table + 1 - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - self.nextrely += 1 - done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -185,14 +180,8 @@ def create(self): npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - label = "APPLY CHANGES & EXIT" + label = "APPLY CHANGES" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -210,16 +199,15 @@ def create(self): def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" widgets = dict() - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels - self.installed_models = sorted([x for x in starters if models[x].installed]) + all_models = self.all_models # master dict of all models, indexed by key + model_list = [x for x in self.starter_models if all_models[x].model_type in ["main", "vae"]] + model_labels = [self.model_labels[x] for x in model_list] widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", + name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", editable=False, labelColor="CAUTION", ) @@ -229,23 +217,24 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] + + checked = [ + model_list.index(x) + for x in model_list + if (show_recommended and all_models[x].recommended) or all_models[x].installed + ] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, + values=model_labels, + value=checked, + max_height=len(model_list) + 1, relx=4, scroll_exit=True, ), - models=keys, + models=model_list, ) self.nextrely += 1 @@ -261,7 +250,8 @@ def add_model_widgets( ) -> dict[str, npyscreen.widget]: """Generic code to create 
model selection widgets""" widgets = dict() - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] + all_models = self.all_models + model_list = [x for x in all_models if all_models[x].model_type == model_type and x not in exclude] model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -297,7 +287,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed + if (show_recommended and all_models[x].recommended) or all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -321,7 +311,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, + max_height=6, scroll_exit=True, editable=True, ) @@ -349,8 +339,7 @@ def add_pipeline_widgets( def resize(self): super().resize() if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] + s.values = [self.model_labels[x] for x in self.starter_pipelines.get("models")] def _toggle_tables(self, value=None): selected_tab = value[0] @@ -382,17 +371,18 @@ def _toggle_tables(self, value=None): self.display() def _get_model_labels(self) -> dict[str, str]: + """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 + result = dict() models = self.all_models - label_width = max([len(models[x].name) for x in models]) + label_width = max([len(models[x].name) for x in self.starter_models]) description_width = window_width - label_width - checkbox_width - spacing_width - result = dict() - for x in models.keys(): - description = models[x].description + for key in self.all_models: + description = models[key].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -400,7 +390,8 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) + result[key] = f"%-{label_width}s %s" % (models[key].name, description) + return result def _get_columns(self) -> int: @@ -411,38 +402,24 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) + mods = "\n".join([self.all_models[x].name for x in remove_models]) return npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) else: return True - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return + @property + def all_models(self) -> Dict[str, UnifiedModelInfo]: + return self.parentApp.install_helper.all_models - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() + @property + def starter_models(self) -> List[str]: + return self.parentApp.install_helper._starter_models - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs=dict( - opt=app.program_opts, - selections=app.install_selections, - conn_out=child_conn, - ), - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() + @property + def installed_models(self) -> List[str]: + return self.parentApp.install_helper._installed_models def on_back(self): self.parentApp.switchFormPrevious() @@ -461,76 +438,6 @@ def on_done(self): self.parentApp.user_cancelled = False self.editing = False - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. 
- saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - def marshall_arguments(self): """ Assemble arguments and store as attributes of the application: @@ -561,16 +468,13 @@ def marshall_arguments(self): models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) + selections.install_models.extend([all_models[x] for x in models_to_install]) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) + models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] + selections.install_models.extend(models) # NOT NEEDED - DONE IN BACKEND NOW # # special case for the ipadapter_models. If any of the adapters are @@ -593,12 +497,12 @@ def marshall_arguments(self): class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): + def __init__(self, opt: Namespace, install_helper: InstallHelper): super().__init__() self.program_opts = opt self.user_cancelled = False - # self.autoload_pending = True self.install_selections = InstallSelections() + self.install_helper = install_helper def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -610,136 +514,55 @@ def onStart(self): ) -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass - - -# -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the type of the V2 checkpoint named {model_path.name}: -[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base) -[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768) -[3] Skip this model and come back later. 
-""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select> ").strip() - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - try: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "abort": - logger.info("Conversion aborted") - return None - else: - return response - except Exception: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() +def list_models(installer: ModelInstallService, model_type: ModelType): + """Print out all models of type model_type.""" + models = installer.store.search_by_name(model_type=model_type) + print(f"Installed models of type `{model_type}`:") + for model in models: + path = (config.models_path / model.path).resolve() + print(f"{model.name:40}{model.base_model.value:14}{path}") # -------------------------------------------------------- def select_and_download_models(opt: Namespace): + """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + install_helper = InstallHelper(config) + installer = install_helper.installer + if opt.list_models: - installer.list_models(opt.list_models) + list_models(installer, opt.list_models) + elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) + selections = InstallSelections( + install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] + ) + install_helper.add_or_delete(selections) + elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) + selections = InstallSelections(install_models=[initial_models.default_model()]) + install_helper.add_or_delete(selections) + elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) + selections = InstallSelections(install_models=initial_models.recommended_models()) + install_helper.add_or_delete(selections) # this is where the TUI is 
called
     else:
-        # needed to support the probe() method running under a subprocess
-        torch.multiprocessing.set_start_method("spawn")
-
         if not set_min_terminal_size(MIN_COLS, MIN_LINES):
             raise WindowTooSmallException(
                 "Could not increase terminal size. Try running again with a larger window or smaller font size."
             )
 
-        installApp = AddModelApplication(opt)
+        installApp = AddModelApplication(opt, install_helper)
         try:
             installApp.run()
         except KeyboardInterrupt as e:
-            if hasattr(installApp, "main_form"):
-                if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
-                    logger.info("Terminating subprocesses")
-                    installApp.main_form.subprocess.terminate()
-                    installApp.main_form.subprocess = None
-            raise e
-        process_and_execute(opt, installApp.install_selections)
+            print("Aborted...")
+            sys.exit(-1)
+
+    install_helper.add_or_delete(installApp.install_selections)
 
 
 # -------------------------------------
@@ -753,7 +576,7 @@ def main():
     parser.add_argument(
         "--delete",
         nargs="*",
-        help="List of names of models to idelete",
+        help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
     )
     parser.add_argument(
         "--full-precision",
@@ -780,14 +603,6 @@
         choices=[x.value for x in ModelType],
         help="list installed models",
     )
-    parser.add_argument(
-        "--config_file",
-        "-c",
-        dest="config_file",
-        type=str,
-        default=None,
-        help="path to configuration file to create",
-    )
     parser.add_argument(
         "--root_dir",
         dest="root",
diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py
index 06d5473fa3d..19d044ee85e 100644
--- a/invokeai/frontend/install/widgets.py
+++ b/invokeai/frontend/install/widgets.py
@@ -19,7 +19,7 @@
 
 # minimum size for UIs
 MIN_COLS = 150
-MIN_LINES = 40
+MIN_LINES = 45
 
 
 class WindowTooSmallException(Exception):
@@ -264,6 +264,17 @@ def h_select(self, ch):
         self.on_changed(self.value)
 
 
+class CheckboxWithChanged(npyscreen.Checkbox):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.on_changed = None
+
+    def whenToggled(self):
+        super().whenToggled()
+        if self.on_changed:
+            self.on_changed(self.value)
+
+
 class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
     """Row of radio buttons.
Spacebar to select.""" diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 8fa02cb49cc..dd42f18c098 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,21 +6,36 @@ """ import argparse import curses +import re import sys from argparse import Namespace from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple import npyscreen from npyscreen import widget import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType +from invokeai.backend.model_manager import ( + BaseModelType, + ModelConfigStore, + ModelFormat, + ModelType, + ModelVariantType, + get_config_store, +) +from invokeai.backend.model_manager.merge import ModelMerger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = InvokeAIAppConfig.get_config() +BASE_TYPES = [ + (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), + (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), + (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), +] + def _parse_args() -> Namespace: parser = argparse.ArgumentParser(description="InvokeAI model merging") @@ -48,7 +63,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--base_model", type=str, - choices=[x.value for x in BaseModelType], + choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( @@ -106,9 +121,9 @@ def afterEditing(self): def create(self): window_height, window_width = curses.initscr().getmaxyx() - - self.model_names = self.get_model_names() self.current_base = 0 + self.models = self.get_models(BASE_TYPES[self.current_base][0]) + self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -128,10 +143,7 @@ def create(self): self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, - values=[ - "Models Built on SD-1.x", - "Models Built on SD-2.x", - ], + values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, @@ -262,19 +274,19 @@ def on_cancel(self): sys.exit(0) def marshall_arguments(self) -> dict: - model_names = self.model_names + model_keys = [x[0] for x in self.models] models = [ - model_names[self.model1.value[0]], - model_names[self.model2.value[0]], + model_keys[self.model1.value[0]], + model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: - models.append(model_names[self.model3.value[0] - 1]) + models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] args = dict( - model_names=models, + model_keys=models, base_model=tuple(BaseModelType)[self.base_select.value[0]], alpha=self.alpha.value, interp=interp, @@ -309,17 +321,18 @@ def validate_field_values(self) -> bool: else: return True - def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]: - model_names = [ - info["model_name"] - for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) - if info["model_format"] == "diffusers" + def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name + models = [ + (x.key, x.name) + for x in 
self.model_manager.search_by_name(model_type=ModelType.Main, base_model=base_model)
+            if x.model_format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
         ]
-        return sorted(model_names)
+        return sorted(models, key=lambda x: x[1])
 
-    def _populate_models(self, value=None):
-        base_model = tuple(BaseModelType)[value[0]]
-        self.model_names = self.get_model_names(base_model)
+    def _populate_models(self, value: List[int]):
+        base_model = BASE_TYPES[value[0]][0]
+        self.models = self.get_models(base_model)
+        self.model_names = [x[1] for x in self.models]
         models_plus_none = self.model_names.copy()
         models_plus_none.insert(0, "None")
 
@@ -331,7 +344,7 @@
 
 
 class Mergeapp(npyscreen.NPSAppManaged):
-    def __init__(self, model_manager: ModelManager):
+    def __init__(self, model_manager: ModelConfigStore):
         super().__init__()
         self.model_manager = model_manager
 
@@ -341,14 +354,13 @@ def onStart(self):
 
 
 def run_gui(args: Namespace):
-    model_manager = ModelManager(config.model_conf_path)
+    model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
 
     mergeapp = Mergeapp(model_manager)
     mergeapp.run()
-    args = mergeapp.merge_arguments
-    merger = ModelMerger(model_manager)
-    merger.merge_diffusion_models_and_save(**args)
-    logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
+    merger = ModelMerger(model_manager, config)
+    merger.merge_diffusion_models_and_save(**vars(args))
+    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
 
 
 def run_cli(args: Namespace):
@@ -361,13 +373,31 @@ def run_cli(args: Namespace):
         args.merged_model_name = "+".join(args.model_names)
         logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
 
-    model_manager = ModelManager(config.model_conf_path)
+    model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
     assert (
-        not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
+        len(model_manager.search_by_name(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
 
     merger = ModelMerger(model_manager)
-    merger.merge_diffusion_models_and_save(**vars(args))
+    model_keys = []
+    for name in args.model_names:
+        if len(name) == 32 and re.match(r"^[0-9a-f]{32}$", name):
+            model_keys.append(name)
+        else:
+            models = model_manager.search_by_name(
+                model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
+            )
+            assert len(models) > 0, f"{name}: Unknown model"
+            assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
+ model_keys.append(models[0].key) + + merger.merge_diffusion_models_and_save( + alpha=args.alpha, + model_keys=model_keys, + merged_model_name=args.merged_model_name, + interp=args.interp, + force=args.force, + ) logger.info(f'Models merged into new model: "{args.merged_model_name}".') @@ -375,6 +405,8 @@ def main(): args = _parse_args() if args.root_dir: config.parse_args(["--root", str(args.root_dir)]) + else: + config.parse_args([]) try: if args.front_end: diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py index f3911f7e0e2..7236511ddb5 100755 --- a/invokeai/frontend/training/textual_inversion.py +++ b/invokeai/frontend/training/textual_inversion.py @@ -22,6 +22,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import ModelConfigStore, ModelType, get_config_store from ...backend.training import do_textual_inversion_training, parse_args @@ -275,10 +276,13 @@ def validate_field_values(self) -> bool: return True def get_model_names(self) -> Tuple[List[str], int]: - conf = OmegaConf.load(config.root_dir / "configs/models.yaml") - model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"] - defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]] - default = defaults[0] if len(defaults) > 0 else 0 + global config + store: ModelConfigStore = get_config_store(config.model_conf_path) + main_models = store.search_by_name(model_type=ModelType.Main) + model_names = [ + f"{x.base_model.value}/{x.model_type.value}/{x.name}" for x in main_models if x.model_format == "diffusers" + ] + default = 0 return (model_names, default) def marshall_arguments(self) -> dict: @@ -384,6 +388,7 @@ def previous_args() -> dict: def do_front_end(args: Namespace): + global config saved_args = previous_args() myapplication = MyApplication(saved_args=saved_args) myapplication.run() @@ -399,7 +404,7 @@ def do_front_end(args: Namespace): save_args(args) try: - do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args) + do_textual_inversion_training(config, **args) copy_to_embeddings_folder(args) except Exception as e: logger.error("An exception occurred during training. 
The exception was:") @@ -413,6 +418,7 @@ def main(): args = parse_args() config = InvokeAIAppConfig.get_config() + config.parse_args([]) # change root if needed if args.root_dir: diff --git a/invokeai/frontend/web/.gitignore b/invokeai/frontend/web/.gitignore index cacf107e1bb..306338ca1d9 100644 --- a/invokeai/frontend/web/.gitignore +++ b/invokeai/frontend/web/.gitignore @@ -35,6 +35,7 @@ stats.html !.yarn/releases !.yarn/sdks !.yarn/versions +.vite # Yalc .yalc diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index a77d58d07fb..9d7153e2dbc 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -238,7 +238,7 @@ const modelsFilter = < T extends | MainModelConfigEntity | LoRAModelConfigEntity - | OnnxModelConfigEntity, + | OnnxModelConfigEntity >( data: EntityState | undefined, model_type: ModelType, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index e095bce8cab..4da61c79109 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -243,7 +243,6 @@ export const modelsApi = api.injectEndpoints({ { type: 'MainModel', id: LIST_TAG }, 'Model', ]; - if (result) { tags.push( ...result.ids.map((id) => ({ diff --git a/pyproject.toml b/pyproject.toml index f4bbf011021..643d415debd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "fastapi==0.88.0", "fastapi-events==0.8.0", "huggingface-hub~=0.16.4", + "imohash~=1.0.0", "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids "matplotlib", # needed for plotting of Penner easing functions "mediapipe", # needed for "mediapipeface" controlnet model @@ -106,6 +107,7 @@ dependencies = [ "pytest>6.0.0", "pytest-cov", "pytest-datadir", + "requests-testadapter", ] "xformers" = [ "xformers~=0.0.19; sys_platform!='darwin'", @@ -140,7 +142,6 @@ dependencies = [ "invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers" "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion" "invokeai-model-install" = "invokeai.frontend.install.model_install:main" -"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main" "invokeai-update" = "invokeai.frontend.install.invokeai_update:main" "invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main" "invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli" diff --git a/scripts/convert_models_config_to_3.2.py b/scripts/convert_models_config_to_3.2.py new file mode 100644 index 00000000000..7ce6f357441 --- /dev/null +++ b/scripts/convert_models_config_to_3.2.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team + +""" +convert_models_config_to_3.2.py. + +This script converts a pre-3.2 models.yaml file into the 3.2 format. +The main difference is that each model is identified by a unique hash, +rather than the concatenation of base, type and name used previously. + +In addition, there are more metadata fields attached to each model. 
+These will mostly be empty after conversion, but will be populated +when new models are downloaded from HuggingFace or Civitae. +""" +import argparse +from pathlib import Path + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager.storage import migrate_models_store + + +def main(): + parser = argparse.ArgumentParser(description="Convert a pre-3.2 models.yaml into the 3.2 version.") + parser.add_argument("--root", type=Path, help="Alternate root directory containing the models.yaml to convert") + parser.add_argument( + "--outfile", + type=Path, + default=Path("./models-3.2.yaml"), + help="File to write to. A file with suffix '.yaml' will use the YAML format. A file with an extension of '.db' will be treated as a SQLite3 database.", + ) + args = parser.parse_args() + config_args = ["--root", args.root.as_posix()] if args.root else [] + + config = InvokeAIAppConfig.get_config() + config.parse_args(config_args) + migrate_models_store(config) + + +if __name__ == "__main__": + main() diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 05741de806d..cf7ce39d4af 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -1,9 +1,19 @@ #!/bin/env python +"""Little command-line utility for probing a model on disk.""" + import argparse +import json +import sys from pathlib import Path -from invokeai.backend.model_management.model_probe import ModelProbe +from invokeai.backend.model_manager import InvalidModelException, ModelProbe, SchedulerPredictionType + + +def helper(model_path: Path): + print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr) + return SchedulerPredictionType.VPrediction + parser = argparse.ArgumentParser(description="Probe model type") parser.add_argument( @@ -14,5 +24,8 @@ args = parser.parse_args() for path in args.model_path: - info = ModelProbe().probe(path) - print(f"{path}: {info}") + try: + info = ModelProbe.probe(path, helper) + print(f"{path}:{json.dumps(info.dict(), sort_keys=True, indent=4)}") + except InvalidModelException as exc: + print(exc) diff --git a/tests/app/__init__.py b/tests/AA_nodes/__init__.py similarity index 100% rename from tests/app/__init__.py rename to tests/AA_nodes/__init__.py diff --git a/tests/nodes/test_graph_execution_state.py b/tests/AA_nodes/test_graph_execution_state.py similarity index 98% rename from tests/nodes/test_graph_execution_state.py rename to tests/AA_nodes/test_graph_execution_state.py index e43075bd32c..5ca809163d2 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/AA_nodes/test_graph_execution_state.py @@ -49,7 +49,10 @@ def mock_services() -> InvocationServices: conn=db_conn, table_name="graph_executions", lock=lock ) return InvocationServices( - model_manager=None, # type: ignore + download_queue=None, # type: ignore + model_loader=None, # type: ignore + model_installer=None, # type: ignore + model_record_store=None, # type: ignore events=TestEventService(), logger=logging, # type: ignore images=None, # type: ignore diff --git a/tests/nodes/test_invoker.py b/tests/AA_nodes/test_invoker.py similarity index 97% rename from tests/nodes/test_invoker.py rename to tests/AA_nodes/test_invoker.py index 7c636c3ecad..af03e1f6f81 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/AA_nodes/test_invoker.py @@ -59,7 +59,10 @@ def mock_services() -> InvocationServices: conn=db_conn, table_name="graph_executions", lock=lock ) return InvocationServices( - model_manager=None, # type: ignore + download_queue=None, # type: ignore + 
model_loader=None, # type: ignore + model_installer=None, # type: ignore + model_record_store=None, # type: ignore events=TestEventService(), logger=logging, # type: ignore images=None, # type: ignore diff --git a/tests/nodes/test_node_graph.py b/tests/AA_nodes/test_node_graph.py similarity index 100% rename from tests/nodes/test_node_graph.py rename to tests/AA_nodes/test_node_graph.py diff --git a/tests/nodes/test_nodes.py b/tests/AA_nodes/test_nodes.py similarity index 100% rename from tests/nodes/test_nodes.py rename to tests/AA_nodes/test_nodes.py diff --git a/tests/nodes/test_session_queue.py b/tests/AA_nodes/test_session_queue.py similarity index 99% rename from tests/nodes/test_session_queue.py rename to tests/AA_nodes/test_session_queue.py index f28ec1ac540..353615d7d34 100644 --- a/tests/nodes/test_session_queue.py +++ b/tests/AA_nodes/test_session_queue.py @@ -12,7 +12,8 @@ populate_graph, prepare_values_to_insert, ) -from tests.nodes.test_nodes import PromptTestInvocation + +from .test_nodes import PromptTestInvocation @pytest.fixture diff --git a/tests/nodes/test_sqlite.py b/tests/AA_nodes/test_sqlite.py similarity index 100% rename from tests/nodes/test_sqlite.py rename to tests/AA_nodes/test_sqlite.py diff --git a/tests/test_config.py b/tests/AB_config/test_config.py similarity index 94% rename from tests/test_config.py rename to tests/AB_config/test_config.py index 2b2492f6a6e..7bf573b8e9a 100644 --- a/tests/test_config.py +++ b/tests/AB_config/test_config.py @@ -6,6 +6,8 @@ from omegaconf import OmegaConf from pydantic import ValidationError +from invokeai.app.services.config import InvokeAIAppConfig + @pytest.fixture def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: @@ -55,7 +57,6 @@ def test_use_init(patch_rootdir): # note that we explicitly set omegaconf dict and argv here # so that the values aren't read from ~invokeai/invokeai.yaml and # sys.argv respectively. 
- from invokeai.app.services.config import InvokeAIAppConfig conf1 = InvokeAIAppConfig.get_config() assert conf1 @@ -73,8 +74,6 @@ def test_use_init(patch_rootdir): def test_legacy(): - from invokeai.app.services.config import InvokeAIAppConfig - conf = InvokeAIAppConfig.get_config() assert conf conf.parse_args(conf=init3, argv=[]) @@ -86,8 +85,6 @@ def test_legacy(): def test_argv_override(): - from invokeai.app.services.config import InvokeAIAppConfig - conf = InvokeAIAppConfig.get_config() conf.parse_args(conf=init1, argv=["--always_use_cpu", "--max_cache=10"]) assert conf.always_use_cpu @@ -96,8 +93,6 @@ def test_argv_override(): def test_env_override(patch_rootdir): - from invokeai.app.services.config import InvokeAIAppConfig - # argv overrides conf = InvokeAIAppConfig() conf.parse_args(conf=init1, argv=["--max_cache=10"]) @@ -129,8 +124,6 @@ def test_env_override(patch_rootdir): def test_root_resists_cwd(patch_rootdir): - from invokeai.app.services.config import InvokeAIAppConfig - previous = os.environ["INVOKEAI_ROOT"] cwd = Path(os.getcwd()).resolve() @@ -146,8 +139,6 @@ def test_root_resists_cwd(patch_rootdir): def test_type_coercion(patch_rootdir): - from invokeai.app.services.config import InvokeAIAppConfig - conf = InvokeAIAppConfig().get_config() conf.parse_args(argv=["--root=/tmp/foobar"]) assert conf.root == Path("/tmp/foobar") diff --git a/tests/AB_config/test_model_config2.py b/tests/AB_config/test_model_config2.py new file mode 100644 index 00000000000..589f0c9ffa9 --- /dev/null +++ b/tests/AB_config/test_model_config2.py @@ -0,0 +1,166 @@ +""" +Test the refactored model config classes. +""" + +from invokeai.backend.model_manager.config import ( + InvalidModelConfigException, + LoRAConfig, + MainCheckpointConfig, + MainDiffusersConfig, + ModelConfigFactory, + ONNXSD1Config, + ONNXSD2Config, + TextualInversionConfig, + ValidationError, +) + + +def test_checkpoints(): + raw = dict( + path="/tmp/foo.ckpt", + name="foo", + base_model="sd-1", + model_type="main", + config="/tmp/foo.yaml", + variant="normal", + model_format="checkpoint", + ) + config = ModelConfigFactory.make_config(raw) + assert isinstance(config, MainCheckpointConfig) + assert config.model_format == "checkpoint" + assert config.base_model == "sd-1" + assert config.vae is None + + +def test_diffusers(): + raw = dict( + path="/tmp/foo", + name="foo", + base_model="sd-2", + model_type="main", + variant="inpaint", + model_format="diffusers", + vae="/tmp/foobar/vae.pt", + ) + config = ModelConfigFactory.make_config(raw) + assert isinstance(config, MainDiffusersConfig) + assert config.model_format == "diffusers" + assert config.base_model == "sd-2" + assert config.variant == "inpaint" + assert config.vae == "/tmp/foobar/vae.pt" + + +def test_invalid_diffusers(): + raw = dict( + path="/tmp/foo", + name="foo", + base_model="sd-2", + model_type="main", + variant="inpaint", + config="/tmp/foo.ckpt", + model_format="diffusers", + ) + # This is expected to fail with a validation error, because + # diffusers format does not have a `config` field + try: + ModelConfigFactory.make_config(raw) + assert False, "Validation should have failed" + except InvalidModelConfigException: + assert True + + +def test_lora(): + raw = dict( + path="/tmp/foo", + name="foo", + base_model="sdxl", + model_type="lora", + model_format="lycoris", + ) + config = ModelConfigFactory.make_config(raw) + assert isinstance(config, LoRAConfig) + assert config.model_format == "lycoris" + raw["model_format"] = "diffusers" + config = 
ModelConfigFactory.make_config(raw) + assert isinstance(config, LoRAConfig) + assert config.model_format == "diffusers" + + +def test_embedding(): + raw = dict( + path="/tmp/foo", + name="foo", + base_model="sdxl-refiner", + model_type="embedding", + model_format="embedding_file", + ) + config = ModelConfigFactory.make_config(raw) + assert isinstance(config, TextualInversionConfig) + assert config.model_format == "embedding_file" + + +def test_onnx(): + raw = dict( + path="/tmp/foo.ckpt", + name="foo", + base_model="sd-1", + model_type="onnx", + variant="normal", + model_format="onnx", + ) + config = ModelConfigFactory.make_config(raw) + assert isinstance(config, ONNXSD1Config) + assert config.model_format == "onnx" + + raw["base_model"] = "sd-2" + # this should not validate without the upcast_attention field + try: + ModelConfigFactory.make_config(raw) + assert False, "Config should not have validated without upcast_attention" + except InvalidModelConfigException: + assert True + + raw["upcast_attention"] = True + raw["prediction_type"] = "epsilon" + config = ModelConfigFactory.make_config(raw) + assert isinstance(config, ONNXSD2Config) + assert config.upcast_attention + + +def test_assignment(): + raw = dict( + path="/tmp/foo.ckpt", + name="foo", + base_model="sd-2", + model_type="onnx", + variant="normal", + model_format="onnx", + upcast_attention=True, + prediction_type="epsilon", + ) + config = ModelConfigFactory.make_config(raw) + config.upcast_attention = False + assert not config.upcast_attention + try: + config.prediction_type = "not valid" + assert False, "Config should not have accepted invalid assignment" + except ValidationError: + assert True + + +def test_invalid_combination(): + raw = dict( + path="/tmp/foo.ckpt", + name="foo", + base_model="sd-2", + model_type="main", + variant="normal", + model_format="onnx", + upcast_attention=True, + prediction_type="epsilon", + ) + try: + ModelConfigFactory.make_config(raw) + assert False, "This should have raised an InvalidModelConfigException" + except InvalidModelConfigException: + assert True diff --git a/tests/test_path.py b/tests/AB_config/test_path.py similarity index 100% rename from tests/test_path.py rename to tests/AB_config/test_path.py diff --git a/tests/AC_model_manager/README.txt b/tests/AC_model_manager/README.txt new file mode 100644 index 00000000000..7a19d4c1525 --- /dev/null +++ b/tests/AC_model_manager/README.txt @@ -0,0 +1,8 @@ +These tests are placed in an "x_" folder so that they are run after +the node tests. If they run beforehand the nodes tests blow up. I +suspect that there are conflicts arising from the in-memory +InvokeAIAppConfig object, but even when I take care to create a fresh +object each time, the problem persists, so perhaps it is something +else? 
+ +- Lincoln diff --git a/tests/backend/model_management/test_libc_util.py b/tests/AC_model_manager/test_libc_util.py similarity index 86% rename from tests/backend/model_management/test_libc_util.py rename to tests/AC_model_manager/test_libc_util.py index a517db4c903..cddb5d025f9 100644 --- a/tests/backend/model_management/test_libc_util.py +++ b/tests/AC_model_manager/test_libc_util.py @@ -1,6 +1,6 @@ import pytest -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from invokeai.backend.model_manager.libc_util import LibcUtil, Struct_mallinfo2 def test_libc_util_mallinfo2(): diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/AC_model_manager/test_memory_snapshot.py similarity index 87% rename from tests/backend/model_management/test_memory_snapshot.py rename to tests/AC_model_manager/test_memory_snapshot.py index 80aed7b7ba5..bb9e71f903f 100644 --- a/tests/backend/model_management/test_memory_snapshot.py +++ b/tests/AC_model_manager/test_memory_snapshot.py @@ -1,7 +1,7 @@ import pytest -from invokeai.backend.model_management.libc_util import Struct_mallinfo2 -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.libc_util import Struct_mallinfo2 +from invokeai.backend.model_manager.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff def test_memory_snapshot_capture(): diff --git a/tests/AC_model_manager/test_model_download.py b/tests/AC_model_manager/test_model_download.py new file mode 100644 index 00000000000..8fed2f71c8f --- /dev/null +++ b/tests/AC_model_manager/test_model_download.py @@ -0,0 +1,368 @@ +"""Test the queued download facility""" + +import tempfile +import time +from pathlib import Path + +import requests +from requests import HTTPError +from requests_testadapter import TestAdapter + +import invokeai.backend.model_manager.download.model_queue as download_queue +from invokeai.backend.model_manager.download import ( + DownloadJobBase, + DownloadJobStatus, + ModelDownloadQueue, + UnknownJobIDException, +) + +# Allow for at least one chunk to be fetched during the pause/unpause test. +# Otherwise pause test doesn't work because whole file contents are read +# before pause is received. +download_queue.DOWNLOAD_CHUNK_SIZE = 16500 + +# Prevent pytest deprecation warnings +TestAdapter.__test__ = False + +# Disable some tests that require the internet. +INTERNET_AVAILABLE = requests.get("http://www.google.com/").status_code == 200 + +######################################################################################## +# Lots of dummy content here to test model download without using lots of bandwidth +# The repo_id tests are not self-contained because they still need to use the HF API +# to retrieve metainformation about the files to retrieve. However, the big weights files +# are not downloaded. + +# If the internet is not available, then the repo_id tests are skipped, but the single +# URL tests are still run. 
+ +session = requests.Session() +for i in ["12345", "9999", "54321"]: + content = ( + b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) + ) # for pause tests, must make content large + session.mount( + f"http://www.civitai.com/models/{i}", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": f'filename="mock{i}.safetensors"', + }, + ), + ) + +# here are some malformed URLs to test +# missing the content length +session.mount( + "http://www.civitai.com/models/missing", + TestAdapter( + b"Missing content length", + headers={ + "Content-Disposition": 'filename="missing.txt"', + }, + ), +) +# not found test +session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + +# prevent us from going to civitai to get metadata +session.mount("https://civitai.com/api/download/models/", TestAdapter(b"Not found", status=404)) +session.mount("https://civitai.com/api/v1/models/", TestAdapter(b"Not found", status=404)) +session.mount("https://civitai.com/api/v1/model-versions/", TestAdapter(b"Not found", status=404)) + +# specifies a content disposition that may overwrite files in the parent directory +session.mount( + "http://www.civitai.com/models/malicious", + TestAdapter( + b"Malicious URL", + headers={ + "Content-Disposition": 'filename="../badness.txt"', + }, + ), +) +# Would create a path that is too long +session.mount( + "http://www.civitai.com/models/long", + TestAdapter( + b"Malicious URL", + headers={ + "Content-Disposition": f'filename="{"i"*1000}"', + }, + ), +) + +# mock HuggingFace URLs +hf_sd2_paths = [ + "feature_extractor/preprocessor_config.json", + "scheduler/scheduler_config.json", + "text_encoder/config.json", + "text_encoder/model.fp16.safetensors", + "text_encoder/model.safetensors", + "text_encoder/pytorch_model.fp16.bin", + "text_encoder/pytorch_model.bin", + "tokenizer/merges.txt", + "tokenizer/special_tokens_map.json", + "tokenizer/tokenizer_config.json", + "tokenizer/vocab.json", + "unet/config.json", + "unet/diffusion_pytorch_model.fp16.safetensors", + "unet/diffusion_pytorch_model.safetensors", + "vae/config.json", + "vae/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.safetensors", +] +for path in hf_sd2_paths: + url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}" + path = Path(path).as_posix() + filename = Path(path).name + content = b"This is the content for path " + bytearray(path, "utf-8") + session.mount( + url, + TestAdapter( + content, + status=200, + headers={"Content-Length": len(content), "Content-Disposition": f'filename="{filename}"'}, + ), + ) + +# This is the content of `model_index.json` for stable-diffusion-2-1 +model_index_content = b'{"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.8.0", "feature_extractor": ["transformers", "CLIPImageProcessor"], "requires_safety_checker": false, "safety_checker": [null, null], "scheduler": ["diffusers", "DDIMScheduler"], "text_encoder": ["transformers", "CLIPTextModel"], "tokenizer": ["transformers", "CLIPTokenizer"], "unet": ["diffusers", "UNet2DConditionModel"], "vae": ["diffusers", "AutoencoderKL"]}' + +session.mount( + "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/model_index.json", + TestAdapter( + model_index_content, + status=200, + headers={"Content-Length": len(model_index_content), "Content-Disposition": 'filename="model_index.json"'}, + ), +) + +# 
================================================================================================================== # + + +def test_basic_queue_download(): + events = list() + + def event_handler(job: DownloadJobBase): + events.append(job.status) + + queue = ModelDownloadQueue( + requests_session=session, + event_handlers=[event_handler], + ) + with tempfile.TemporaryDirectory() as tmpdir: + job = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False) + assert isinstance(job, DownloadJobBase), "expected the job to be of type DownloadJobBase" + assert isinstance(job.id, int), "expected the job id to be numeric" + assert job.status == "idle", "expected job status to be idle" + assert job.status == DownloadJobStatus.IDLE + + queue.start_job(job) + queue.join() + assert events[0] == DownloadJobStatus.ENQUEUED + assert events[-1] == DownloadJobStatus.COMPLETED + assert DownloadJobStatus.RUNNING in events + assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist" + + +def test_queue_priority(): + queue = ModelDownloadQueue( + requests_session=session, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + job1 = queue.create_download_job( + priority=0, source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False + ) + job2 = queue.create_download_job( + priority=10, source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False + ) + + assert job1 < job2 + + queue.start_all_jobs() + queue.join() + assert job1.job_sequence < job2.job_sequence + + job1 = queue.create_download_job( + priority=10, source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False + ) + job2 = queue.create_download_job( + priority=0, source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False + ) + + assert job2 < job1 + + queue.start_all_jobs() + queue.join() + assert job2.job_sequence < job1.job_sequence + + assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist" + assert Path(tmpdir, "mock9999.safetensors").exists(), f"expected {tmpdir}/mock9999.safetensors to exist" + + +def test_repo_id_download(): + if not INTERNET_AVAILABLE: + return + repo_id = "stabilityai/stable-diffusion-2-1" + queue = ModelDownloadQueue( + requests_session=session, + ) + + # first with fp16 variant + with tempfile.TemporaryDirectory() as tmpdir: + queue.create_download_job(source=repo_id, destdir=tmpdir, variant="fp16", start=True) + queue.join() + repo_root = Path(tmpdir, "stable-diffusion-2-1") + assert repo_root.exists() + assert Path(repo_root, "model_index.json").exists() + assert Path(repo_root, "text_encoder", "config.json").exists() + assert Path(repo_root, "text_encoder", "model.fp16.safetensors").exists() + + # then without fp16 + with tempfile.TemporaryDirectory() as tmpdir: + queue.create_download_job(source=repo_id, destdir=tmpdir, start=True) + queue.join() + repo_root = Path(tmpdir, "stable-diffusion-2-1") + assert Path(repo_root, "text_encoder", "model.safetensors").exists() + assert not Path(repo_root, "text_encoder", "model.fp16.safetensors").exists() + + +def test_bad_urls(): + queue = ModelDownloadQueue( + requests_session=session, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # do we handle 404 and other HTTP errors? 
+ job = queue.create_download_job(source="http://www.civitai.com/models/broken", destdir=tmpdir) + queue.join() + assert job.status == "error" + assert isinstance(job.error, HTTPError) + assert str(job.error) == "NOT FOUND" + + # Do we handle missing content length field? + job = queue.create_download_job(source="http://www.civitai.com/models/missing", destdir=tmpdir) + queue.join() + assert job.status == "completed" + assert job.total_bytes == 0 + assert job.bytes > 0 + assert job.bytes == Path(tmpdir, "missing.txt").stat().st_size + + # Don't let the URL specify a filename with slashes or double dots... (e.g. '../../etc/passwd') + job = queue.create_download_job(source="http://www.civitai.com/models/malicious", destdir=tmpdir) + queue.join() + assert job.status == "completed" + assert job.destination == Path(tmpdir, "malicious") + assert Path(tmpdir, "malicious").exists() + + # Nor a destination that would exceed the maximum filename or path length + job = queue.create_download_job(source="http://www.civitai.com/models/long", destdir=tmpdir) + queue.join() + assert job.status == "completed" + assert job.destination == Path(tmpdir, "long") + assert Path(tmpdir, "long").exists() + + # create a foreign job which will be invalid for the queue + bad_job = DownloadJobBase(id=999, source="mock", destination="mock") + try: + queue.start_job(bad_job) # this should fail + succeeded = True + except UnknownJobIDException: + succeeded = False + assert not succeeded + + +def test_pause_cancel_url(): # this one is tricky because of potential race conditions + def event_handler(job: DownloadJobBase): + time.sleep(0.5) # slow down the thread so that we can recover the paused state + + queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler]) + with tempfile.TemporaryDirectory() as tmpdir: + job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False) + job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False) + job3 = queue.create_download_job(source="http://www.civitai.com/models/54321", destdir=tmpdir, start=False) + + assert job1.status == "idle" + queue.start_job(job1) + queue.start_job(job3) + time.sleep(0.1) # wait for enqueueing + assert job1.status in ["enqueued", "running"] + + # check pause and restart + queue.pause_job(job1) + time.sleep(0.1) # wait to be paused + assert job1.status == "paused" + + queue.start_job(job1) + time.sleep(0.1) + assert job1.status == "running" + + # check cancel + queue.start_job(job2) + time.sleep(0.1) + assert job2.status == "running" + queue.cancel_job(job2) + time.sleep(0.1) + assert job2.status == "cancelled" + + queue.join() + assert job1.status == "completed" + assert job2.status == "cancelled" + assert job3.status == "completed" + + assert Path(tmpdir, "mock12345.safetensors").exists() + assert Path(tmpdir, "mock9999.safetensors").exists() is False, "cancelled file should be deleted" + assert Path(tmpdir, "mock54321.safetensors").exists() + + assert len(queue.list_jobs()) == 3 + queue.prune_jobs() + assert len(queue.list_jobs()) == 0 + + +def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions + def event_handler(job: DownloadJobBase): + time.sleep(0.1) # slow down the thread by blocking it just a bit at every step + + if not INTERNET_AVAILABLE: + return + + repo_id = "stabilityai/stable-diffusion-2-1" + queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler]) + + with 
tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2: + job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False) + job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False) + assert job1.status == "idle" + queue.start_job(job1) + time.sleep(0.1) # wait for enqueueing + assert job1.status in ["enqueued", "running"] + + # check pause and restart + queue.pause_job(job1) + time.sleep(0.1) # wait to be paused + assert job1.status == "paused" + + queue.start_job(job1) + time.sleep(0.5) + assert job1.status == "running" + + # check cancel + queue.start_job(job2) + time.sleep(0.1) + assert job2.status == "running" + queue.cancel_job(job2) + + queue.join() + assert job1.status == "completed" + assert job2.status == "cancelled" + + assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists() + assert not Path( + tmpdir2, "stable-diffusion-2-1", "model_index.json" + ).exists(), "cancelled file should be deleted" + + assert len(queue.list_jobs()) == 2 + queue.prune_jobs() + assert len(queue.list_jobs()) == 0 diff --git a/tests/AC_model_manager/test_model_install_service.py b/tests/AC_model_manager/test_model_install_service.py new file mode 100644 index 00000000000..055f9eeeae6 --- /dev/null +++ b/tests/AC_model_manager/test_model_install_service.py @@ -0,0 +1,84 @@ +import tempfile +from pathlib import Path +from typing import Any + +from pydantic import BaseModel + +from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig +from invokeai.app.services.events import EventServiceBase +from invokeai.app.services.model_install_service import ModelInstallService +from invokeai.app.services.model_loader_service import ModelLoadService +from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.backend.model_manager import BaseModelType, ModelType + +# This is a very little embedding model that we can use to test installation +TEST_MODEL = "test_embedding.safetensors" + + +class DummyEvent(BaseModel): + """Dummy Event to use with Dummy Event service.""" + + event_name: str + status: str + + +class DummyEventService(EventServiceBase): + """Dummy event service for testing.""" + + events: list + + def __init__(self): + super().__init__() + self.events = list() + + def dispatch(self, event_name: str, payload: Any) -> None: + """Dispatch an event by appending it to self.events.""" + self.events.append(DummyEvent(event_name=event_name, status=payload["job"].status)) + + +def test_install(datadir: Path): + """Test installation of an itty-bitty embedding.""" + # create a temporary root directory for install to target + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + (tmp_path / "models").mkdir() + (tmp_path / "configs").mkdir() + config = InvokeAIAppConfig( + root=tmp_path, + model_config_db=tmp_path / "configs" / "models.yaml", + models_dir=tmp_path / "models", + ) + + event_bus = DummyEventService() + mm_store = ModelRecordServiceBase.open(config) + mm_load = ModelLoadService(config, mm_store) + mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus) + + source = datadir / TEST_MODEL + mm_install.install_model(source) + id_map = mm_install.wait_for_installs() + print(id_map) + assert source in id_map, "model did not install; id_map empty" + assert id_map[source] is not None, "model did not install: source field empty" + + # test the events + assert len(event_bus.events) > 0, "no events 
received" + assert len(event_bus.events) == 3 + + event_names = set([x.event_name for x in event_bus.events]) + assert "model_event" in event_names + event_payloads = set([x.status for x in event_bus.events]) + assert "enqueued" in event_payloads + assert "running" in event_payloads + assert "completed" in event_payloads + + key = id_map[source] + model = mm_store.get_model(key) # may raise an exception here + assert Path(config.models_path / model.path).exists(), "generated path incorrect" + assert model.base_model == BaseModelType.StableDiffusion1, "probe of model base type didn't work" + assert model.model_type == ModelType.TextualInversion, "probe of model type didn't work" + + model_info = mm_load.get_model(key) + assert model_info, "model did not load" + with model_info as model: + assert model is not None, "model context not working" diff --git a/tests/AC_model_manager/test_model_install_service/test_embedding.safetensors b/tests/AC_model_manager/test_model_install_service/test_embedding.safetensors new file mode 100644 index 00000000000..ebd6a044b57 Binary files /dev/null and b/tests/AC_model_manager/test_model_install_service/test_embedding.safetensors differ diff --git a/tests/AC_model_manager/test_model_manager.py b/tests/AC_model_manager/test_model_manager.py new file mode 100644 index 00000000000..60de5bffd76 --- /dev/null +++ b/tests/AC_model_manager/test_model_manager.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend import SubModelType +from invokeai.backend.model_manager.loader import ModelLoad + +BASIC_MODEL_NAME = "sdxl-base-1-0" +VAE_OVERRIDE_MODEL_NAME = "sdxl-base-with-custom-vae-1-0" +VAE_NULL_OVERRIDE_MODEL_NAME = "sdxl-base-with-empty-vae-1-0" + + +@pytest.fixture +def model_manager(datadir) -> ModelLoad: + config = InvokeAIAppConfig(root=datadir, model_config_db="configs/relative_sub.models.yaml") + return ModelLoad(config=config) + + +def test_get_model_names(model_manager: ModelLoad): + store = model_manager.store + names = [x.name for x in store.all_models()] + assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] + + +def test_get_model_path_for_diffusers(model_manager: ModelLoad, datadir: Path): + models = model_manager.store.search_by_name(model_name=BASIC_MODEL_NAME) + assert len(models) == 1 + model_config = models[0] + top_model_path, is_override = model_manager._get_model_path(model_config) + expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" + assert top_model_path == expected_model_path + assert not is_override + + +def test_get_model_path_for_overridden_vae(model_manager: ModelLoad, datadir: Path): + models = model_manager.store.search_by_name(model_name=VAE_OVERRIDE_MODEL_NAME) + assert len(models) == 1 + model_config = models[0] + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" + assert vae_model_path == expected_vae_path + assert is_override + + +def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoad, datadir: Path): + model_config = model_manager.store.search_by_name(model_name=VAE_NULL_OVERRIDE_MODEL_NAME)[0] + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + assert not is_override diff --git a/tests/AC_model_manager/test_model_manager/configs/relative_sub.models.yaml 
b/tests/AC_model_manager/test_model_manager/configs/relative_sub.models.yaml
new file mode 100644
index 00000000000..71b05e93567
--- /dev/null
+++ b/tests/AC_model_manager/test_model_manager/configs/relative_sub.models.yaml
@@ -0,0 +1,30 @@
+__metadata__:
+  version: 3.2
+ed799245c762f6d0a9ddfd4e31fdb010:
+  name: sdxl-base-1-0
+  path: sdxl/main/SDXL base 1_0
+  base_model: sdxl
+  model_type: main
+  model_format: diffusers
+  variant: normal
+  description: SDXL base v1.0
+
+fa78e05dbf51c540ff9256eb65446fd6:
+  name: sdxl-base-with-custom-vae-1-0
+  path: sdxl/main/SDXL base 1_0
+  base_model: sdxl
+  model_type: main
+  variant: normal
+  model_format: diffusers
+  description: SDXL with customized VAE
+  vae: sdxl/vae/sdxl-vae-fp16-fix/
+
+8a79e05d9f51c5ffff9256eb65446fd6:
+  name: sdxl-base-with-empty-vae-1-0
+  path: sdxl/main/SDXL base 1_0
+  base_model: sdxl
+  model_type: main
+  variant: normal
+  model_format: diffusers
+  description: SDXL with customized VAE
+  vae: ''
diff --git a/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json b/tests/AC_model_manager/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json
similarity index 100%
rename from tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json
rename to tests/AC_model_manager/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json
diff --git a/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json b/tests/AC_model_manager/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json
similarity index 100%
rename from tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json
rename to tests/AC_model_manager/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json
diff --git a/tests/test_model_probe.py b/tests/AC_model_manager/test_model_probe.py
similarity index 89%
rename from tests/test_model_probe.py
rename to tests/AC_model_manager/test_model_probe.py
index 248b7d602fd..d7ee4f52a4d 100644
--- a/tests/test_model_probe.py
+++ b/tests/AC_model_manager/test_model_probe.py
@@ -3,7 +3,7 @@
 import pytest
 
 from invokeai.backend import BaseModelType
-from invokeai.backend.model_management.model_probe import VaeFolderProbe
+from invokeai.backend.model_manager.probe import VaeFolderProbe
 
 
 @pytest.mark.parametrize(
diff --git a/tests/test_model_probe/vae/sd-vae-ft-mse/config.json b/tests/AC_model_manager/test_model_probe/vae/sd-vae-ft-mse/config.json
similarity index 100%
rename from tests/test_model_probe/vae/sd-vae-ft-mse/config.json
rename to tests/AC_model_manager/test_model_probe/vae/sd-vae-ft-mse/config.json
diff --git a/tests/test_model_probe/vae/sdxl-vae/config.json b/tests/AC_model_manager/test_model_probe/vae/sdxl-vae/config.json
similarity index 100%
rename from tests/test_model_probe/vae/sdxl-vae/config.json
rename to tests/AC_model_manager/test_model_probe/vae/sdxl-vae/config.json
diff --git a/tests/test_model_probe/vae/taesd/config.json b/tests/AC_model_manager/test_model_probe/vae/taesd/config.json
similarity index 100%
rename from tests/test_model_probe/vae/taesd/config.json
rename to tests/AC_model_manager/test_model_probe/vae/taesd/config.json
diff --git a/tests/test_model_probe/vae/taesdxl/config.json b/tests/AC_model_manager/test_model_probe/vae/taesdxl/config.json
similarity index 100%
rename from tests/test_model_probe/vae/taesdxl/config.json
rename to tests/AC_model_manager/test_model_probe/vae/taesdxl/config.json
diff --git a/tests/AC_model_manager/test_model_storage_file.py b/tests/AC_model_manager/test_model_storage_file.py
new file mode 100644
index 00000000000..80f19776840
--- /dev/null
+++ b/tests/AC_model_manager/test_model_storage_file.py
@@ -0,0 +1,136 @@
+"""
+Test the refactored model config classes.
+"""
+
+from hashlib import sha256
+
+import pytest
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.model_record_service import (
+    ModelRecordServiceBase,
+    ModelRecordServiceFile,
+    UnknownModelException,
+)
+from invokeai.backend.model_manager.config import DiffusersConfig, ModelType, TextualInversionConfig, VaeDiffusersConfig
+
+
+@pytest.fixture
+def store(datadir) -> ModelRecordServiceBase:
+    InvokeAIAppConfig(root=datadir)
+    return ModelRecordServiceFile(datadir / "configs" / "models.yaml")
+
+
+def example_config() -> TextualInversionConfig:
+    return TextualInversionConfig(
+        path="/tmp/pokemon.bin",
+        name="old name",
+        base_model="sd-1",
+        model_type="embedding",
+        model_format="embedding_file",
+        author="Anonymous",
+    )
+
+
+def test_add(store: ModelRecordServiceBase):
+    raw = dict(
+        path="/tmp/foo.ckpt",
+        name="model1",
+        base_model="sd-1",
+        model_type="main",
+        config="/tmp/foo.yaml",
+        variant="normal",
+        model_format="checkpoint",
+    )
+    store.add_model("key1", raw)
+    config1 = store.get_model("key1")
+    assert config1 is not None
+    raw["name"] = "model2"
+    raw["base_model"] = "sd-2"
+    raw["model_format"] = "diffusers"
+    raw.pop("config")
+    store.add_model("key2", raw)
+    config2 = store.get_model("key2")
+    assert config1.name == "model1"
+    assert config2.name == "model2"
+    assert config1.base_model == "sd-1"
+    assert config2.base_model == "sd-2"
+
+
+def test_update(store: ModelRecordServiceBase):
+    config = example_config()
+    store.add_model("key1", config)
+    config = store.get_model("key1")
+    assert config.name == "old name"
+
+    config.name = "new name"
+    store.update_model("key1", config)
+    new_config = store.get_model("key1")
+    assert new_config.name == "new name"
+
+    try:
+        store.update_model("unknown_key", config)
+        assert False, "expected UnknownModelException"
+    except UnknownModelException:
+        assert True
+
+
+def test_delete(store: ModelRecordServiceBase):
+    config = example_config()
+    store.add_model("key1", config)
+    config = store.get_model("key1")
+    store.del_model("key1")
+    try:
+        config = store.get_model("key1")
+        assert False, "expected fetch of deleted model to raise exception"
+    except UnknownModelException:
+        assert True
+
+    try:
+        store.del_model("unknown")
+        assert False, "expected delete of unknown model to raise exception"
+    except UnknownModelException:
+        assert True
+
+
+def test_exists(store: ModelRecordServiceBase):
+    config = example_config()
+    store.add_model("key1", config)
+    assert store.exists("key1")
+    assert not store.exists("key2")
+
+
+def test_filter(store: ModelRecordServiceBase):
+    config1 = DiffusersConfig(
+        path="/tmp/config1", name="config1", base_model="sd-1", model_type="main", tags=["sfw", "commercial", "fantasy"]
+    )
+    config2 = DiffusersConfig(
+        path="/tmp/config2", name="config2", base_model="sd-1", model_type="main", tags=["sfw", "commercial"]
+    )
+    config3 = VaeDiffusersConfig(path="/tmp/config3", name="config3", base_model="sd-1", model_type="vae", tags=["sfw"])
+    for c in config1, config2, config3:
+        store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
+    matches = store.search_by_name(model_type="main")
+    assert len(matches) == 2
+    assert matches[0].name in {"config1", "config2"}
+
+    matches = store.search_by_name(model_type="vae")
+    assert len(matches) == 1
+    assert matches[0].name == "config3"
+    assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
+    assert isinstance(matches[0].model_type, ModelType)  # This tests that we get proper enums back
+
+    matches = store.search_by_tag(["sfw"])
+    assert len(matches) == 3
+
+    matches = store.search_by_tag(["sfw", "commercial"])
+    assert len(matches) == 2
+
+    matches = store.search_by_tag(["sfw", "commercial", "fantasy"])
+    assert len(matches) == 1
+
+    matches = store.search_by_tag(["sfw", "commercial", "fantasy", "nonsense"])
+    assert len(matches) == 0
+
+    matches = store.all_models()
+    assert len(matches) == 3
diff --git a/tests/AC_model_manager/test_model_storage_sql.py b/tests/AC_model_manager/test_model_storage_sql.py
new file mode 100644
index 00000000000..93b3969c88f
--- /dev/null
+++ b/tests/AC_model_manager/test_model_storage_sql.py
@@ -0,0 +1,140 @@
+"""
+Test the refactored model config classes.
+"""
+
+import sys
+from hashlib import sha256
+
+import pytest
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.model_record_service import (
+    ModelRecordServiceBase,
+    ModelRecordServiceSQL,
+    UnknownModelException,
+)
+from invokeai.backend.model_manager.config import DiffusersConfig, ModelType, TextualInversionConfig, VaeDiffusersConfig
+
+
+@pytest.fixture
+def store(datadir) -> ModelRecordServiceBase:
+    InvokeAIAppConfig(root=datadir)
+    return ModelRecordServiceSQL.from_db_file(datadir / "databases" / "models.db")
+
+
+def example_config() -> TextualInversionConfig:
+    return TextualInversionConfig(
+        path="/tmp/pokemon.bin",
+        name="old name",
+        base_model="sd-1",
+        model_type="embedding",
+        model_format="embedding_file",
+        author="Anonymous",
+    )
+
+
+def test_add(store: ModelRecordServiceBase):
+    raw = dict(
+        path="/tmp/foo.ckpt",
+        name="model1",
+        base_model="sd-1",
+        model_type="main",
+        config="/tmp/foo.yaml",
+        variant="normal",
+        model_format="checkpoint",
+    )
+    store.add_model("key1", raw)
+    config1 = store.get_model("key1")
+    assert config1 is not None
+    raw["name"] = "model2"
+    raw["base_model"] = "sd-2"
+    raw["model_format"] = "diffusers"
+    raw.pop("config")
+    store.add_model("key2", raw)
+    config2 = store.get_model("key2")
+    assert config1.name == "model1"
+    assert config2.name == "model2"
+    assert config1.base_model == "sd-1"
+    assert config2.base_model == "sd-2"
+
+
+def test_update(store: ModelRecordServiceBase):
+    config = example_config()
+    store.add_model("key1", config)
+    config = store.get_model("key1")
+    assert config.name == "old name"
+
+    config.name = "new name"
+    store.update_model("key1", config)
+    new_config = store.get_model("key1")
+    assert new_config.name == "new name"
+
+    try:
+        store.update_model("unknown_key", config)
+        assert False, "expected UnknownModelException"
+    except UnknownModelException:
+        assert True
+
+
+def test_delete(store: ModelRecordServiceBase):
+    config = example_config()
+    store.add_model("key1", config)
+    config = store.get_model("key1")
+    store.del_model("key1")
+    try:
+        config = store.get_model("key1")
+        assert False, "expected fetch of deleted model to raise exception"
+    except UnknownModelException:
+        assert True
+
+    # a bug in sqlite3 in python 3.9 prevents DEL from returning number of
+    # deleted rows!
+    if sys.version_info.major == 3 and sys.version_info.minor > 9:
+        try:
+            store.del_model("unknown")
+            assert False, "expected delete of unknown model to raise exception"
+        except UnknownModelException:
+            assert True
+
+
+def test_exists(store: ModelRecordServiceBase):
+    config = example_config()
+    store.add_model("key1", config)
+    assert store.exists("key1")
+    assert not store.exists("key2")
+
+
+def test_filter(store: ModelRecordServiceBase):
+    config1 = DiffusersConfig(
+        path="/tmp/config1", name="config1", base_model="sd-1", model_type="main", tags=["sfw", "commercial", "fantasy"]
+    )
+    config2 = DiffusersConfig(
+        path="/tmp/config2", name="config2", base_model="sd-1", model_type="main", tags=["sfw", "commercial"]
+    )
+    config3 = VaeDiffusersConfig(path="/tmp/config3", name="config3", base_model="sd-1", model_type="vae", tags=["sfw"])
+    for c in config1, config2, config3:
+        store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
+    matches = store.search_by_name(model_type="main")
+    assert len(matches) == 2
+    assert matches[0].name in {"config1", "config2"}
+
+    matches = store.search_by_name(model_type="vae")
+    assert len(matches) == 1
+    assert matches[0].name == "config3"
+    assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
+    assert isinstance(matches[0].model_type, ModelType)  # This tests that we get proper enums back
+
+    matches = store.search_by_tag(["sfw"])
+    assert len(matches) == 3
+
+    matches = store.search_by_tag(["sfw", "commercial"])
+    assert len(matches) == 2
+
+    matches = store.search_by_tag(["sfw", "commercial", "fantasy"])
+    assert len(matches) == 1
+
+    matches = store.search_by_tag(["sfw", "commercial", "fantasy", "nonsense"])
+    assert len(matches) == 0
+
+    matches = store.all_models()
+    assert len(matches) == 3
diff --git a/tests/app/util/__init__.py b/tests/AE_other_backend/__init__.py
similarity index 100%
rename from tests/app/util/__init__.py
rename to tests/AE_other_backend/__init__.py
diff --git a/tests/inpainting/coyote-inpainting.prompt b/tests/AE_other_backend/inpainting/coyote-inpainting.prompt
similarity index 100%
rename from tests/inpainting/coyote-inpainting.prompt
rename to tests/AE_other_backend/inpainting/coyote-inpainting.prompt
diff --git a/tests/inpainting/coyote-input.webp b/tests/AE_other_backend/inpainting/coyote-input.webp
similarity index 100%
rename from tests/inpainting/coyote-input.webp
rename to tests/AE_other_backend/inpainting/coyote-input.webp
diff --git a/tests/inpainting/coyote-mask.webp b/tests/AE_other_backend/inpainting/coyote-mask.webp
similarity index 100%
rename from tests/inpainting/coyote-mask.webp
rename to tests/AE_other_backend/inpainting/coyote-mask.webp
diff --git a/tests/inpainting/original.json b/tests/AE_other_backend/inpainting/original.json
similarity index 100%
rename from tests/inpainting/original.json
rename to tests/AE_other_backend/inpainting/original.json
diff --git a/tests/backend/__init__.py b/tests/AE_other_backend/ip_adapter/__init__.py
similarity index 100%
rename from tests/backend/__init__.py
rename to tests/AE_other_backend/ip_adapter/__init__.py
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/AE_other_backend/ip_adapter/test_ip_adapter.py
similarity index 96%
rename from tests/backend/ip_adapter/test_ip_adapter.py
rename to tests/AE_other_backend/ip_adapter/test_ip_adapter.py
index 7f634ee1feb..dc723bff42f 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/AE_other_backend/ip_adapter/test_ip_adapter.py
@@ -2,7 +2,7 @@
 import torch
 
 from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
 from invokeai.backend.util.test_utils import install_and_load_model
 
 
diff --git a/tests/app/util/test_controlnet_utils.py b/tests/AE_other_backend/test_controlnet_utils.py
similarity index 100%
rename from tests/app/util/test_controlnet_utils.py
rename to tests/AE_other_backend/test_controlnet_utils.py
diff --git a/tests/README.txt b/tests/README.txt
new file mode 100644
index 00000000000..4e7e1aeae1c
--- /dev/null
+++ b/tests/README.txt
@@ -0,0 +1,8 @@
+The nodes tests need to run before the others in order to avoid a race
+condition involving fixture initialization. Please see
+https://discord.com/channels/1020123559063990373/1156089584808120382/1156802853323673620
+for an explanation.
+
+For this reason, the subtests are grouped into alphabetically-ordered
+folders. Do not use numeric prefixes (e.g. 00_nodes) because this
+breaks python's import system.
diff --git a/tests/backend/ip_adapter/__init__.py b/tests/backend/ip_adapter/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/tests/conftest.py b/tests/conftest.py
index 8618f5e1025..a063ec77bc6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,8 @@
 # conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
 # without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
+from invokeai.app.services.model_install_service import ModelInstallService  # noqa: F401
+
 
 # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
 # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
-from invokeai.backend.util.test_utils import model_installer, torch_device  # noqa: F401
+from invokeai.backend.util.test_utils import torch_device  # noqa: F401
diff --git a/tests/nodes/__init__.py b/tests/nodes/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py
deleted file mode 100644
index 5a28862e1f2..00000000000
--- a/tests/test_model_manager.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from pathlib import Path
-
-import pytest
-
-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType
-
-BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
-VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
-VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
-
-
-@pytest.fixture
-def model_manager(datadir) -> ModelManager:
-    InvokeAIAppConfig.get_config(root=datadir)
-    return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
-
-
-def test_get_model_names(model_manager: ModelManager):
-    names = model_manager.model_names()
-    assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
-
-
-def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
-    model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
-    top_model_path, is_override = model_manager._get_model_path(model_config)
-    expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
-    assert top_model_path == expected_model_path
-    assert not is_override
-
-
-def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
-    model_config = model_manager._get_model_config(
-        VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
-    )
-    vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
-    expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
-    assert vae_model_path == expected_vae_path
-    assert is_override
-
-
-def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
-    model_config = model_manager._get_model_config(
-        VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
-    )
-    vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
-    assert not is_override
diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml
deleted file mode 100644
index 2e26710d13f..00000000000
--- a/tests/test_model_manager/configs/relative_sub.models.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-__metadata__:
-  version: 3.0.0
-
-sdxl/main/SDXL base:
-  path: sdxl/main/SDXL base 1_0
-  description: SDXL base v1.0
-  variant: normal
-  format: diffusers
-
-sdxl/main/SDXL with VAE:
-  path: sdxl/main/SDXL base 1_0
-  description: SDXL with customized VAE
-  vae: sdxl/vae/sdxl-vae-fp16-fix/
-  variant: normal
-  format: diffusers
-
-sdxl/main/SDXL with empty VAE:
-  path: sdxl/main/SDXL base 1_0
-  description: SDXL with customized VAE
-  vae: ''
-  variant: normal
-  format: diffusers