diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 124589f44ce..ce1ee9e808a 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -9,11 +9,15 @@ complex functionality. ## Invocations Directory -InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes. +InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These +can be used as examples to create your own nodes. -New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup. +New nodes should be added to a subfolder in `nodes` direction found at the root +level of the InvokeAI installation location. Nodes added to this folder will be +able to be used upon application startup. + +Example `nodes` subfolder structure: -Example `nodes` subfolder structure: ```py ├── __init__.py # Invoke-managed custom node loader │ @@ -30,14 +34,14 @@ Example `nodes` subfolder structure: └── fancy_node.py ``` -Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded. - See the README in the nodes folder for more examples: +Each node folder must have an `__init__.py` file that imports its nodes. Only +nodes imported in the `__init__.py` file are loaded. See the README in the nodes +folder for more examples: ```py from .cool_node import CoolInvocation ``` - ## Creating A New Invocation In order to understand the process of creating a new Invocation, let us actually @@ -131,7 +135,6 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") @@ -167,7 +170,6 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") @@ -197,7 +199,6 @@ from invokeai.app.invocations.image import ImageOutput class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") @@ -229,30 +230,17 @@ class ResizeInvocation(BaseInvocation): height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") def invoke(self, context: InvocationContext) -> ImageOutput: - # Load the image using InvokeAI's predefined Image Service. Returns the PIL image. - image = context.services.images.get_pil_image(self.image.image_name) + # Load the input image as a PIL image + image = context.images.get_pil(self.image.image_name) - # Resizing the image + # Resize the image resized_image = image.resize((self.width, self.height)) - # Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image. 
- output_image = context.services.images.create( - image=resized_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) - - # Returning the Image - return ImageOutput( - image=ImageField( - image_name=output_image.image_name, - ), - width=output_image.width, - height=output_image.height, - ) + # Save the image + image_dto = context.images.save(image=resized_image) + + # Return an ImageOutput + return ImageOutput.build(image_dto) ``` **Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a @@ -343,27 +331,25 @@ class ImageColorStringOutput(BaseInvocationOutput): That's all there is to it. - +Custom fields only support connection inputs in the Workflow Editor. diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 880c8b24801..b19699de73d 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -28,7 +28,7 @@ model. These are the: Hugging Face, as well as discriminating among model versions in Civitai, but can be used for arbitrary content. - * _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**) + * _ModelLoadServiceBase_ Responsible for loading a model from disk into RAM and VRAM and getting it ready for inference. @@ -41,10 +41,10 @@ The four main services can be found in * `invokeai/app/services/model_records/` * `invokeai/app/services/model_install/` * `invokeai/app/services/downloads/` -* `invokeai/app/services/model_loader/` (**under development**) +* `invokeai/app/services/model_load/` Code related to the FastAPI web API can be found in -`invokeai/app/api/routers/model_records.py`. +`invokeai/app/api/routers/model_manager_v2.py`. *** @@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but `ModelType`, `ModelFormat` and `BaseModelType` are string enums that are defined in `invokeai.backend.model_manager.config`. They are also imported by, and can be reexported from, -`invokeai.app.services.model_record_service`: +`invokeai.app.services.model_manager.model_records`: ``` -from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType +from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType ``` The `path` field can be absolute or relative. If relative, it is taken @@ -123,7 +123,7 @@ taken to be the `models_dir` directory. `variant` is an enumerated string class with values `normal`, `inpaint` and `depth`. If needed, it can be imported if needed from -either `invokeai.app.services.model_record_service` or +either `invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### ONNXSD2Config @@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or | `upcast_attention` | bool | Model requires its attention module to be upcast | The `SchedulerPredictionType` enum can be imported from either -`invokeai.app.services.model_record_service` or +`invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### Other config classes @@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base models. This works OK for some models, such as the IP Adapter image encoders, but is an all-or-nothing proposition. 
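To make the base-compatibility rule concrete, here is a minimal sketch of a helper that treats `BaseModelType.Any` as matching every base. The helper itself is illustrative, not part of the model manager API, and it assumes the config record exposes a `base` attribute as in the example records shown elsewhere in this document:

```
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType


def is_compatible(config: AnyModelConfig, base: BaseModelType) -> bool:
    # BaseModelType.Any indicates the model claims compatibility with every base.
    return config.base in (base, BaseModelType.Any)
```

An IP Adapter image encoder registered with `BaseModelType.Any` would pass this check for every base, which is exactly the all-or-nothing behaviour described above.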
-Another issue is that the config class hierarchy is paralleled to some -extent by a `ModelBase` class hierarchy defined in -`invokeai.backend.model_manager.models.base` and its subclasses. These -are classes representing the models after they are loaded into RAM and -include runtime information such as load status and bytes used. Some -of the fields, including `name`, `model_type` and `base_model`, are -shared between `ModelConfigBase` and `ModelBase`, and this is a -potential source of confusion. - ## Reading and Writing Model Configuration Records The `ModelRecordService` provides the ability to retrieve model @@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the `InvocationContext` object: ``` -store = context.services.model_record_store +store = context.services.model_manager.store ``` or from elsewhere in the code by accessing -`ApiDependencies.invoker.services.model_record_store`. +`ApiDependencies.invoker.services.model_manager.store`. ### Creating a `ModelRecordService` @@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a `ModelRecordServiceFile` object: ``` -from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile +from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile store = ModelRecordServiceSQL.from_connection(connection, lock) store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db') @@ -252,7 +243,7 @@ So a typical startup pattern would be: ``` import sqlite3 from invokeai.app.services.thread import lock -from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() @@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False) store = ModelRecordServiceBase.open(config, db_conn, lock) ``` -_A note on simultaneous access to `invokeai.db`_: The current InvokeAI -service architecture for the image and graph databases is careful to -use a shared sqlite3 connection and a thread lock to ensure that two -threads don't attempt to access the database simultaneously. However, -the default `sqlite3` library used by Python reports using -**Serialized** mode, which allows multiple threads to access the -database simultaneously using multiple database connections (see -https://www.sqlite.org/threadsafe.html and -https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore -it should be safe to allow the record service to open its own SQLite -database connection. Opening a model record service should then be as -simple as `ModelRecordServiceBase.open(config)`. - ### Fetching a Model's Configuration from `ModelRecordServiceBase` Configurations can be retrieved in several ways. @@ -468,6 +446,44 @@ required parameters: Once initialized, the installer will provide the following methods: +#### install_job = installer.heuristic_import(source, [config], [access_token]) + +This is a simplified interface to the installer which takes a source +string, an optional model configuration dictionary and an optional +access token. + +The `source` is a string that can be any of these forms + +1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`) +2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`) +3. 
A HuggingFace repo_id with any of the following formats: + - `model/name` -- entire model + - `model/name:fp32` -- entire model, using the fp32 variant + - `model/name:fp16:vae` -- vae submodel, using the fp16 variant + - `model/name::vae` -- vae submodel, using default precision + - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant + - `model/name::path/to/model.safetensors` -- an individual model file, default variant + +Note that by specifying a relative path to the top of the HuggingFace +repo, you can download and install arbitrary models files. + +The variant, if not provided, will be automatically filled in with +`fp32` if the user has requested full precision, and `fp16` +otherwise. If a variant that does not exist is requested, then the +method will install whatever HuggingFace returns as its default +revision. + +`config` is an optional dict of values that will override the +autoprobed values for model type, base, scheduler prediction type, and +so forth. See [Model configuration and +probing](#Model-configuration-and-probing) for details. + +`access_token` is an optional access token for accessing resources +that need authentication. + +The method will return a `ModelInstallJob`. This object is discussed +at length in the following section. + #### install_job = installer.import_model() The `import_model()` method is the core of the installer. The @@ -486,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model +source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file -source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL -source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token +source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL +source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token for source in [source1, source2, source3, source4, source5, source6, source7]: install_job = installer.install_model(source) @@ -544,7 +561,6 @@ can be passed to `import_model()`. attributes returned by the model prober. See the section below for details. - #### LocalModelSource This is used for a model that is located on a locally-accessible Posix @@ -737,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return True if the job is in the complete, errored or cancelled states. -#### Model confguration and probing +#### Model configuration and probing The install service uses the `invokeai.backend.model_manager.probe` module during import to determine the model's type, base type, and @@ -776,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will return from the call if jobs aren't completed in the specified time. An argument of 0 (the default) will block indefinitely. +#### jobs = installer.wait_for_job(job, [timeout]) + +Like `wait_for_installs()`, but block until a specific job has +completed or errored, and then return the job. 
The optional `timeout` +argument will return from the call if the job doesn't complete in the +specified time. An argument of 0 (the default) will block +indefinitely. + #### jobs = installer.list_jobs() Return a list of all active and complete `ModelInstallJobs`. @@ -838,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally deletes the corresponding model weights file(s), regardless of whether they are inside or outside the InvokeAI models hierarchy. + +#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) + +This utility routine will download the model file located at source, +cache it, and return the path to the cached file. It does not attempt +to determine the model type, probe its configuration values, or +register it with the models database. + +You may provide an access token if the remote source requires +authorization. The call will block indefinitely until the file is +completely downloaded, cancelled or raises an error of some sort. If +you provide a timeout (in seconds), the call will raise a +`TimeoutError` exception if the download hasn't completed in the +specified period. + +You may use this mechanism to request any type of file, not just a +model. The file will be stored in a subdirectory of +`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the +cache, its path will be returned without redownloading it. + +Be aware that the models cache is cleared of infrequently-used files +and directories at regular intervals when the size of the cache +exceeds the value specified in Invoke's `convert_cache` configuration +variable. + #### List[str]=installer.scan_directory(scan_dir: Path, install: bool) This method will recursively scan the directory indicated in @@ -1128,7 +1177,7 @@ job = queue.create_download_job( event_handlers=[my_handler1, my_handler2], # if desired start=True, ) - ``` +``` The `filename` argument forces the downloader to use the specified name for the file rather than the name provided by the remote source, @@ -1171,6 +1220,13 @@ queue or was not created by this queue. This method will block until all the active jobs in the queue have reached a terminal state (completed, errored or cancelled). +#### queue.wait_for_job(job, [timeout]) + +This method will block until the indicated job has reached a terminal +state (completed, errored or cancelled). If the optional timeout is +provided, the call will block for at most timeout seconds, and raise a +TimeoutError otherwise. + #### jobs = queue.list_jobs() This will return a list of all jobs, including ones that have not yet @@ -1449,9 +1505,9 @@ set of keys to the corresponding model config objects. Find all model metadata records that have the given author and return a set of keys to the corresponding model config objects. -# The remainder of this documentation is provisional, pending implementation of the Load service +*** -## Let's get loaded, the lowdown on ModelLoadService +## The Lowdown on the ModelLoadService The `ModelLoadService` is responsible for loading a named model into memory so that it can be used for inference. Despite the fact that it @@ -1465,7 +1521,7 @@ create alternative instances if you wish. ### Creating a ModelLoadService object The class is defined in -`invokeai.app.services.model_loader_service`. It is initialized with +`invokeai.app.services.model_load`. 
It is initialized with an InvokeAIAppConfig object, from which it gets configuration information such as the user's desired GPU and precision, and with a previously-created `ModelRecordServiceBase` object, from which it @@ -1475,8 +1531,8 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_record_service import ModelRecordServiceBase -from invokeai.app.services.model_loader_service import ModelLoadService +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.app.services.model_load import ModelLoadService config = InvokeAIAppConfig.get_config() store = ModelRecordServiceBase.open(config) @@ -1487,14 +1543,11 @@ Note that we are relying on the contents of the application configuration to choose the implementation of `ModelRecordServiceBase`. -### get_model(key, [submodel_type], [context]) -> ModelInfo: +### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel -*** TO DO: change to get_model(key, context=None, **kwargs) - -The `get_model()` method, like its similarly-named cousin in -`ModelRecordService`, receives the unique key that identifies the -model. It loads the model into memory, gets the model ready for use, -and returns a `ModelInfo` object. +The `load_model_by_key()` method receives the unique key that +identifies the model. It loads the model into memory, gets the model +ready for use, and returns a `LoadedModel` object. The optional second argument, `subtype` is a `SubModelType` string enum, such as "vae". It is mandatory when used with a main model, and @@ -1504,46 +1557,64 @@ The optional third argument, `context` can be provided by an invocation to trigger model load event reporting. See below for details. -The returned `ModelInfo` object shares some fields in common with -`ModelConfigBase`, but is otherwise a completely different beast: +The returned `LoadedModel` object contains a copy of the configuration +record returned by the model record `get_model()` method, as well as +the in-memory loaded model: -| **Field Name** | **Type** | **Description** | + +| **Attribute Name** | **Type** | **Description** | |----------------|-----------------|------------------| -| `key` | str | The model key derived from the ModelRecordService database | -| `name` | str | Name of this model | -| `base_model` | BaseModelType | Base model for this model | -| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)| -| `location` | Path or str | Location of the model on the filesystem | -| `precision` | torch.dtype | The torch.precision to use for inference | -| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use | +| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. | +| `model` | AnyModel | The instantiated model (details below) | +| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM | -The types for `ModelInfo` and `SubModelType` can be imported from -`invokeai.app.services.model_loader_service`. +Because the loader can return multiple model types, it is typed to +return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`, +`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and +`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers +models, `EmbeddingModelRaw` is used for LoRA and TextualInversion +models. The others are obvious. 
-To use the model, you use the `ModelInfo` as a context manager using -the following pattern: + +`LoadedModel` acts as a context manager. The context loads the model +into the execution device (e.g. VRAM on CUDA systems), locks the model +in the execution device for the duration of the context, and returns +the model. Use it like this: ``` -model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) +model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) with model_info as vae: image = vae.decode(latents)[0] ``` -The `vae` model will stay locked in the GPU during the period of time -it is in the context manager's scope. - -`get_model()` may raise any of the following exceptions: +`get_model_by_key()` may raise any of the following exceptions: -- `UnknownModelException` -- key not in database -- `ModelNotFoundException` -- key in database but model not found at path -- `InvalidModelException` -- the model is guilty of a variety of sins +- `UnknownModelException` -- key not in database +- `ModelNotFoundException` -- key in database but model not found at path +- `NotImplementedException` -- the loader doesn't know how to load this type of model -** TO DO: ** Resolve discrepancy between ModelInfo.location and -ModelConfig.path. +### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel + +This is similar to `load_model_by_key`, but instead it accepts the +combination of the model's name, type and base, which it passes to the +model record config store for retrieval. If successful, this method +returns a `LoadedModel`. It can raise the following exceptions: + +``` +UnknownModelException -- model with these attributes not known +NotImplementedException -- the loader doesn't know how to load this type of model +ValueError -- more than one model matches this combination of base/type/name +``` + +### load_model_by_config(config, [submodel], [context]) -> LoadedModel + +This method takes an `AnyModelConfig` returned by +ModelRecordService.get_model() and returns the corresponding loaded +model. It may raise a `NotImplementedException`. ### Emitting model loading events -When the `context` argument is passed to `get_model()`, it will +When the `context` argument is passed to `load_model_*()`, it will retrieve the invocation event bus from the passed `InvocationContext` object to emit events on the invocation bus. The two events are "model_load_started" and "model_load_completed". Both carry the @@ -1556,10 +1627,104 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), ) ``` +### Adding Model Loaders + +Model loaders are small classes that inherit from the `ModelLoader` +base class. They typically implement one method `_load_model()` whose +signature is: + +``` +def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, +) -> AnyModel: +``` + +`_load_model()` will be passed the path to the model on disk, an +optional repository variant (used by the diffusers loaders to select, +e.g. the `fp16` variant, and an optional submodel_type for main and +onnx models. + +To install a new loader, place it in +`invokeai/backend/model_manager/load/model_loaders`. 
Inherit from +`ModelLoader` and use the `@AnyModelLoader.register()` decorator to +indicate what type of models the loader can handle. + +Here is a complete example from `generic_diffusers.py`, which is able +to load several different diffusers types: + +``` +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result +``` + +Note that a loader can register itself to handle several different +model types. An exception will be raised if more than one loader tries +to register the same model type. + +#### Conversion + +Some models require conversion to diffusers format before they can be +loaded. These loaders should override two additional methods: + +``` +_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool +_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: +``` + +The first method accepts the model configuration, the path to where +the unmodified model is currently installed, and a proposed +destination for the converted model. This method returns True if the +model needs to be converted. It typically does this by comparing the +last modification time of the original model file to the modification +time of the converted model. In some cases you will also want to check +the modification date of the configuration record, in the event that +the user has changed something like the scheduler prediction type that +will require the model to be re-converted. See `controlnet.py` for an +example of this logic. + +The second method accepts the model configuration, the path to the +original model on disk, and the desired output path for the converted +model. It does whatever it needs to do to get the model into diffusers +format, and returns the Path of the resulting model. (The path should +ordinarily be the same as `output_path`.) 
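To make the timestamp comparison concrete, here is a hedged sketch of how a converting loader might implement these two methods. Only the two signatures quoted above are taken from the codebase; the class itself is hypothetical, the registration decorators are omitted, and the body of `_convert_model()` is a placeholder rather than a call to a real conversion helper:

```
from pathlib import Path

from invokeai.backend.model_manager.config import AnyModelConfig
from ..load_default import ModelLoader  # imported as in the loader example above


class MyConvertingCheckpointLoader(ModelLoader):
    """Illustrative loader that converts a checkpoint to diffusers before loading it."""

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        # Convert if no converted copy exists yet, or if the original file has
        # been modified more recently than the converted copy.
        if not dest_path.exists():
            return True
        return model_path.stat().st_mtime > dest_path.stat().st_mtime

    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
        # Placeholder: write a diffusers-format copy of the model to output_path here.
        ...
        return output_path
```

In a real loader you might also compare against the modification time of the configuration record, as described above for `controlnet.py`.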
+ diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c8309e1729e..378961a0557 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,9 +2,13 @@ from logging import Logger +import torch + from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db -from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -23,11 +27,7 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker -from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage -from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage -from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService -from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue @@ -68,6 +68,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger logger.debug(f"Internet connectivity is {config.internet_available}") output_folder = config.output_path + if output_folder is None: + raise ValueError("Output folder is not set") + image_files = DiskImageFileStorage(f"{output_folder}/images") db = init_db(config=config, logger=logger, image_files=image_files) @@ -84,17 +87,15 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) - model_manager = ModelManagerService(config, logger) - model_record_service = ModelRecordServiceSQL(db=db) + tensors = ObjectSerializerForwardCache( + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True) + ) + conditioning = ObjectSerializerForwardCache( + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) + ) download_queue_service = DownloadQueueService(event_bus=events) - metadata_store = ModelMetadataStore(db=db) - model_install_service = ModelInstallService( - app_config=config, - record_store=model_record_service, - download_queue=download_queue_service, - metadata_store=metadata_store, - event_bus=events, + model_manager = ModelManagerService.build_model_manager( + app_config=configuration, db=db, download_queue=download_queue_service, events=events ) names = SimpleNameService() performance_statistics = InvocationStatsService() @@ -117,12 +118,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records=image_records, images=images, 
invocation_cache=invocation_cache, - latents=latents, logger=logger, model_manager=model_manager, - model_records=model_record_service, download_queue=download_queue_service, - model_install=model_install_service, names=names, performance_statistics=performance_statistics, processor=processor, @@ -131,6 +129,8 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger session_queue=session_queue, urls=urls, workflow_records=workflow_records, + tensors=tensors, + conditioning=conditioning, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py index 92b658c3708..a6e53c7a5c4 100644 --- a/invokeai/app/api/routers/download_queue.py +++ b/invokeai/app/api/routers/download_queue.py @@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]: 400: {"description": "Bad request"}, }, ) -async def prune_downloads(): +async def prune_downloads() -> Response: """Prune completed and errored jobs.""" queue = ApiDependencies.invoker.services.download_queue queue.prune_jobs() @@ -55,7 +55,7 @@ async def download( ) -> DownloadJob: """Download the source URL to the file or directory indicted in dest.""" queue = ApiDependencies.invoker.services.download_queue - return queue.download(source, dest, priority, access_token) + return queue.download(source, Path(dest), priority, access_token) @download_queue_router.get( @@ -87,7 +87,7 @@ async def get_download_job( ) async def cancel_download_job( id: int = Path(description="ID of the download job to cancel."), -): +) -> Response: """Cancel a download job using its ID.""" try: queue = ApiDependencies.invoker.services.download_queue @@ -105,7 +105,7 @@ async def cancel_download_job( 204: {"description": "Download jobs have been cancelled"}, }, ) -async def cancel_all_download_jobs(): +async def cancel_all_download_jobs() -> Response: """Cancel all download jobs.""" ApiDependencies.invoker.services.download_queue.cancel_all_jobs() return Response(status_code=204) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 125896b8d3a..cc60ad1be83 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -8,7 +8,7 @@ from PIL import Image from pydantic import BaseModel, Field, ValidationError -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py new file mode 100644 index 00000000000..2471e0d8c9b --- /dev/null +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -0,0 +1,759 @@ +# Copyright (c) 2023 Lincoln D. 
Stein +"""FastAPI route for model configuration records.""" + +import pathlib +import shutil +from hashlib import sha1 +from random import randbytes +from typing import Any, Dict, List, Optional, Set + +from fastapi import Body, Path, Query, Response +from fastapi.routing import APIRouter +from pydantic import BaseModel, ConfigDict +from starlette.exceptions import HTTPException +from typing_extensions import Annotated + +from invokeai.app.services.model_install import ModelInstallJob, ModelSource +from invokeai.app.services.model_records import ( + DuplicateModelException, + InvalidModelException, + ModelRecordOrderBy, + ModelSummary, + UnknownModelException, +) +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.backend.model_manager.config import ( + AnyModelConfig, + BaseModelType, + MainCheckpointConfig, + ModelFormat, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..dependencies import ApiDependencies + +model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) + + +class ModelsList(BaseModel): + """Return list of configs.""" + + models: List[AnyModelConfig] + + model_config = ConfigDict(use_enum_values=True) + + +class ModelTagSet(BaseModel): + """Return tags for a set of models.""" + + key: str + name: str + author: str + tags: Set[str] + + +############################################################################## +# These are example inputs and outputs that are used in places where Swagger +# is unable to generate a correct example. +############################################################################## +example_model_config = { + "path": "string", + "name": "string", + "base": "sd-1", + "type": "main", + "format": "checkpoint", + "config": "string", + "key": "string", + "original_hash": "string", + "current_hash": "string", + "description": "string", + "source": "string", + "last_modified": 0, + "vae": "string", + "variant": "normal", + "prediction_type": "epsilon", + "repo_variant": "fp16", + "upcast_attention": False, + "ztsnr_training": False, +} + +example_model_input = { + "path": "/path/to/model", + "name": "model_name", + "base": "sd-1", + "type": "main", + "format": "checkpoint", + "config": "configs/stable-diffusion/v1-inference.yaml", + "description": "Model description", + "vae": None, + "variant": "normal", +} + +example_model_metadata = { + "name": "ip_adapter_sd_image_encoder", + "author": "InvokeAI", + "tags": [ + "transformers", + "safetensors", + "clip_vision_model", + "endpoints_compatible", + "region:us", + "has_space", + "license:apache-2.0", + ], + "files": [ + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md", + "path": "ip_adapter_sd_image_encoder/README.md", + "size": 628, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json", + "path": "ip_adapter_sd_image_encoder/config.json", + "size": 560, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors", + "path": "ip_adapter_sd_image_encoder/model.safetensors", + "size": 2528373448, + "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030", + }, + ], + "type": "huggingface", + "id": "InvokeAI/ip_adapter_sd_image_encoder", + "tag_dict": {"license": "apache-2.0"}, + "last_modified": 
"2023-09-23T17:33:25Z", +} + +############################################################################## +# ROUTES +############################################################################## + + +@model_manager_v2_router.get( + "/", + operation_id="list_model_records", +) +async def list_model_records( + base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), + model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), + model_format: Optional[ModelFormat] = Query( + default=None, description="Exact match on the format of the model (e.g. 'diffusers')" + ), +) -> ModelsList: + """Get a list of models.""" + record_store = ApiDependencies.invoker.services.model_manager.store + found_models: list[AnyModelConfig] = [] + if base_models: + for base_model in base_models: + found_models.extend( + record_store.search_by_attr( + base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + ) + ) + else: + found_models.extend( + record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + ) + return ModelsList(models=found_models) + + +@model_manager_v2_router.get( + "/i/{key}", + operation_id="get_model_record", + responses={ + 200: { + "description": "The model configuration was retrieved successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "The model could not be found"}, + }, +) +async def get_model_record( + key: str = Path(description="Key of the model record to fetch."), +) -> AnyModelConfig: + """Get a model record""" + record_store = ApiDependencies.invoker.services.model_manager.store + try: + config: AnyModelConfig = record_store.get_model(key) + return config + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@model_manager_v2_router.get("/summary", operation_id="list_model_summary") +async def list_model_summary( + page: int = Query(default=0, description="The page to get"), + per_page: int = Query(default=10, description="The number of models per page"), + order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), +) -> PaginatedResults[ModelSummary]: + """Gets a page of model summary data.""" + record_store = ApiDependencies.invoker.services.model_manager.store + results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) + return results + + +@model_manager_v2_router.get( + "/meta/i/{key}", + operation_id="get_model_metadata", + responses={ + 200: { + "description": "The model metadata was retrieved successfully", + "content": {"application/json": {"example": example_model_metadata}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "No metadata available"}, + }, +) +async def get_model_metadata( + key: str = Path(description="Key of the model repo metadata to fetch."), +) -> Optional[AnyModelRepoMetadata]: + """Get a model metadata object.""" + record_store = ApiDependencies.invoker.services.model_manager.store + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) + if not result: + raise HTTPException(status_code=404, detail="No metadata for a model with this key") + return result + + +@model_manager_v2_router.get( + "/tags", + 
operation_id="list_tags", +) +async def list_tags() -> Set[str]: + """Get a unique set of all the model tags.""" + record_store = ApiDependencies.invoker.services.model_manager.store + result: Set[str] = record_store.list_tags() + return result + + +@model_manager_v2_router.get( + "/tags/search", + operation_id="search_by_metadata_tags", +) +async def search_by_metadata_tags( + tags: Set[str] = Query(default=None, description="Tags to search for"), +) -> ModelsList: + """Get a list of models.""" + record_store = ApiDependencies.invoker.services.model_manager.store + results = record_store.search_by_metadata_tag(tags) + return ModelsList(models=results) + + +@model_manager_v2_router.patch( + "/i/{key}", + operation_id="update_model_record", + responses={ + 200: { + "description": "The model was updated successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "The model could not be found"}, + 409: {"description": "There is already a model corresponding to the new name"}, + }, + status_code=200, +) +async def update_model_record( + key: Annotated[str, Path(description="Unique key of model")], + info: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> AnyModelConfig: + """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" + logger = ApiDependencies.invoker.services.logger + record_store = ApiDependencies.invoker.services.model_manager.store + try: + model_response: AnyModelConfig = record_store.update_model(key, config=info) + logger.info(f"Updated model: {key}") + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return model_response + + +@model_manager_v2_router.delete( + "/i/{key}", + operation_id="del_model_record", + responses={ + 204: {"description": "Model deleted successfully"}, + 404: {"description": "Model not found"}, + }, + status_code=204, +) +async def del_model_record( + key: str = Path(description="Unique key of model to remove from model registry."), +) -> Response: + """ + Delete model record from database. + + The configuration record will be removed. The corresponding weights files will be + deleted as well if they reside within the InvokeAI "models" directory. 
+ """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + installer.delete(key) + logger.info(f"Deleted model: {key}") + return Response(status_code=204) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + + +@model_manager_v2_router.post( + "/i/", + operation_id="add_model_record", + responses={ + 201: { + "description": "The model added successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + 415: {"description": "Unrecognized file/folder format"}, + }, + status_code=201, +) +async def add_model_record( + config: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> AnyModelConfig: + """Add a model using the configuration information appropriate for its type.""" + logger = ApiDependencies.invoker.services.logger + record_store = ApiDependencies.invoker.services.model_manager.store + if config.key == "": + config.key = sha1(randbytes(100)).hexdigest() + logger.info(f"Created model {config.key} for {config.name}") + try: + record_store.add_model(config.key, config) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + + # now fetch it out + result: AnyModelConfig = record_store.get_model(config.key) + return result + + +@model_manager_v2_router.post( + "/heuristic_import", + operation_id="heuristic_import_model", + responses={ + 201: {"description": "The model imported successfully"}, + 415: {"description": "Unrecognized file/folder format"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, +) +async def heuristic_import( + source: str, + config: Optional[Dict[str, Any]] = Body( + description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", + default=None, + example={"name": "modelT", "description": "antique cars"}, + ), + access_token: Optional[str] = None, +) -> ModelInstallJob: + """Install a model using a string identifier. + + `source` can be any of the following. + + 1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors') + 2. A Url pointing to a single downloadable model file + 3. A HuggingFace repo_id with any of the following formats: + - model/name + - model/name:fp16:vae + - model/name::vae -- use default precision + - model/name:fp16:path/to/model.safetensors + - model/name::path/to/model.safetensors + + `config` is an optional dict containing model configuration values that will override + the ones that are probed automatically. + + `access_token` is an optional access token for use with Urls that require + authentication. + + Models will be downloaded, probed, configured and installed in a + series of background threads. The return object has `status` attribute + that can be used to monitor progress. + + See the documentation for `import_model_record` for more information on + interpreting the job information returned by this route. 
+ """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + result: ModelInstallJob = installer.heuristic_import( + source=source, + config=config, + ) + logger.info(f"Started installation of {source}") + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return result + + +@model_manager_v2_router.post( + "/install", + operation_id="import_model", + responses={ + 201: {"description": "The model imported successfully"}, + 415: {"description": "Unrecognized file/folder format"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, +) +async def import_model( + source: ModelSource, + config: Optional[Dict[str, Any]] = Body( + description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", + default=None, + ), +) -> ModelInstallJob: + """Install a model using its local path, repo_id, or remote URL. + + Models will be downloaded, probed, configured and installed in a + series of background threads. The return object has `status` attribute + that can be used to monitor progress. + + The source object is a discriminated Union of LocalModelSource, + HFModelSource and URLModelSource. Set the "type" field to the + appropriate value: + + * To install a local path using LocalModelSource, pass a source of form: + ``` + { + "type": "local", + "path": "/path/to/model", + "inplace": false + } + ``` + The "inplace" flag, if true, will register the model in place in its + current filesystem location. Otherwise, the model will be copied + into the InvokeAI models directory. + + * To install a HuggingFace repo_id using HFModelSource, pass a source of form: + ``` + { + "type": "hf", + "repo_id": "stabilityai/stable-diffusion-2.0", + "variant": "fp16", + "subfolder": "vae", + "access_token": "f5820a918aaf01" + } + ``` + The `variant`, `subfolder` and `access_token` fields are optional. + + * To install a remote model using an arbitrary URL, pass: + ``` + { + "type": "url", + "url": "http://www.civitai.com/models/123456", + "access_token": "f5820a918aaf01" + } + ``` + The `access_token` field is optonal + + The model's configuration record will be probed and filled in + automatically. To override the default guesses, pass "metadata" + with a Dict containing the attributes you wish to override. + + Installation occurs in the background. Either use list_model_install_jobs() + to poll for completion, or listen on the event bus for the following events: + + * "model_install_running" + * "model_install_completed" + * "model_install_error" + + On successful completion, the event's payload will contain the field "key" + containing the installed ID of the model. On an error, the event's payload + will contain the fields "error_type" and "error" describing the nature of the + error and its traceback, respectively. 
+ + """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + result: ModelInstallJob = installer.import_model( + source=source, + config=config, + ) + logger.info(f"Started installation of {source}") + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return result + + +@model_manager_v2_router.get( + "/import", + operation_id="list_model_install_jobs", +) +async def list_model_install_jobs() -> List[ModelInstallJob]: + """Return the list of model install jobs. + + Install jobs have a numeric `id`, a `status`, and other fields that provide information on + the nature of the job and its progress. The `status` is one of: + + * "waiting" -- Job is waiting in the queue to run + * "downloading" -- Model file(s) are downloading + * "running" -- Model has downloaded and the model probing and registration process is running + * "completed" -- Installation completed successfully + * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * "cancelled" -- Job was cancelled before completion. + + Once completed, information about the model such as its size, base + model, type, and metadata can be retrieved from the `config_out` + field. For multi-file models such as diffusers, information on individual files + can be retrieved from `download_parts`. + + See the example and schema below for more information. + """ + jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() + return jobs + + +@model_manager_v2_router.get( + "/import/{id}", + operation_id="get_model_install_job", + responses={ + 200: {"description": "Success"}, + 404: {"description": "No such job"}, + }, +) +async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: + """ + Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' + for information on the format of the return value. 
+ """ + try: + result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) + return result + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@model_manager_v2_router.delete( + "/import/{id}", + operation_id="cancel_model_install_job", + responses={ + 201: {"description": "The job was cancelled successfully"}, + 415: {"description": "No such job"}, + }, + status_code=201, +) +async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: + """Cancel the model install job(s) corresponding to the given job ID.""" + installer = ApiDependencies.invoker.services.model_manager.install + try: + job = installer.get_job_by_id(id) + except ValueError as e: + raise HTTPException(status_code=415, detail=str(e)) + installer.cancel_job(job) + + +@model_manager_v2_router.patch( + "/import", + operation_id="prune_model_install_jobs", + responses={ + 204: {"description": "All completed and errored jobs have been pruned"}, + 400: {"description": "Bad request"}, + }, +) +async def prune_model_install_jobs() -> Response: + """Prune all completed and errored jobs from the install job list.""" + ApiDependencies.invoker.services.model_manager.install.prune_jobs() + return Response(status_code=204) + + +@model_manager_v2_router.patch( + "/sync", + operation_id="sync_models_to_config", + responses={ + 204: {"description": "Model config record database resynced with files on disk"}, + 400: {"description": "Bad request"}, + }, +) +async def sync_models_to_config() -> Response: + """ + Traverse the models and autoimport directories. + + Model files without a corresponding + record in the database are added. Orphan records without a models file are deleted. + """ + ApiDependencies.invoker.services.model_manager.install.sync_to_config() + return Response(status_code=204) + + +@model_manager_v2_router.put( + "/convert/{key}", + operation_id="convert_model", + responses={ + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, +) +async def convert_model( + key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), +) -> AnyModelConfig: + """ + Permanently convert a model into diffusers format, replacing the safetensors version. + Note that during the conversion process the key and model hash will change. + The return value is the model configuration for the converted model. 
+ """ + logger = ApiDependencies.invoker.services.logger + loader = ApiDependencies.invoker.services.model_manager.load + store = ApiDependencies.invoker.services.model_manager.store + installer = ApiDependencies.invoker.services.model_manager.install + + try: + model_config = store.get_model(key) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + + if not isinstance(model_config, MainCheckpointConfig): + logger.error(f"The model with key {key} is not a main checkpoint model.") + raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") + + # loading the model will convert it into a cached diffusers file + loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) + + # Get the path of the converted model from the loader + cache_path = loader.convert_cache.cache_path(key) + assert cache_path.exists() + + # temporarily rename the original safetensors file so that there is no naming conflict + original_name = model_config.name + model_config.name = f"{original_name}.DELETE" + store.update_model(key, config=model_config) + + # install the diffusers + try: + new_key = installer.install_path( + cache_path, + config={ + "name": original_name, + "description": model_config.description, + "original_hash": model_config.original_hash, + "source": model_config.source, + }, + ) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + + # get the original metadata + if orig_metadata := store.get_metadata(key): + store.metadata_store.add_metadata(new_key, orig_metadata) + + # delete the original safetensors file + installer.delete(key) + + # delete the cached version + shutil.rmtree(cache_path) + + # return the config record for the new diffusers directory + new_config: AnyModelConfig = store.get_model(new_key) + return new_config + + +@model_manager_v2_router.put( + "/merge", + operation_id="merge", + responses={ + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_config}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, +) +async def merge( + keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), + merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), + alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), + force: bool = Body( + description="Force merging of models created with different versions of diffusers", + default=False, + ), + interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), + merge_dest_directory: Optional[str] = Body( + description="Save the merged model to the designated directory (with 'merged_model_name' appended)", + default=None, + ), +) -> AnyModelConfig: + """ + Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. + ``` + Argument Description [default] + -------- ---------------------- + keys List of 2-3 model keys to merge together. All models must use the same base type. + merged_model_name Name for the merged model [Concat model names] + alpha Alpha value (0.0-1.0). 
Higher values give more weight to the second model [0.5] + force If true, force the merge even if the models were generated by different versions of the diffusers library [False] + interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + merge_dest_directory Specify a directory to store the merged model in [models directory] + ``` + """ + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") + dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None + installer = ApiDependencies.invoker.services.model_manager.install + merger = ModelMerger(installer) + model_names = [installer.record_store.get_model(x).name for x in keys] + response = merger.merge_diffusion_models_and_save( + model_keys=keys, + merged_model_name=merged_model_name or "+".join(model_names), + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory=dest, + ) + except UnknownModelException: + raise HTTPException( + status_code=404, + detail=f"One or more of the models '{keys}' not found", + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return response diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py deleted file mode 100644 index f9a3e408985..00000000000 --- a/invokeai/app/api/routers/model_records.py +++ /dev/null @@ -1,472 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein -"""FastAPI route for model configuration records.""" - -import pathlib -from hashlib import sha1 -from random import randbytes -from typing import Any, Dict, List, Optional, Set - -from fastapi import Body, Path, Query, Response -from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict -from starlette.exceptions import HTTPException -from typing_extensions import Annotated - -from invokeai.app.services.model_install import ModelInstallJob, ModelSource -from invokeai.app.services.model_records import ( - DuplicateModelException, - InvalidModelException, - ModelRecordOrderBy, - ModelSummary, - UnknownModelException, -) -from invokeai.app.services.shared.pagination import PaginatedResults -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - BaseModelType, - ModelFormat, - ModelType, -) -from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - -from ..dependencies import ApiDependencies - -model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"]) - - -class ModelsList(BaseModel): - """Return list of configs.""" - - models: List[AnyModelConfig] - - model_config = ConfigDict(use_enum_values=True) - - -class ModelTagSet(BaseModel): - """Return tags for a set of models.""" - - key: str - name: str - author: str - tags: Set[str] - - -@model_records_router.get( - "/", - operation_id="list_model_records", -) -async def list_model_records( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), - model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), - model_format: Optional[ModelFormat] = Query( - default=None, description="Exact match on the format of the model (e.g. 
'diffusers')" - ), -) -> ModelsList: - """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records - found_models: list[AnyModelConfig] = [] - if base_models: - for base_model in base_models: - found_models.extend( - record_store.search_by_attr( - base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format - ) - ) - else: - found_models.extend( - record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) - ) - return ModelsList(models=found_models) - - -@model_records_router.get( - "/i/{key}", - operation_id="get_model_record", - responses={ - 200: {"description": "Success"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - }, -) -async def get_model_record( - key: str = Path(description="Key of the model record to fetch."), -) -> AnyModelConfig: - """Get a model record""" - record_store = ApiDependencies.invoker.services.model_records - try: - return record_store.get_model(key) - except UnknownModelException as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@model_records_router.get("/meta", operation_id="list_model_summary") -async def list_model_summary( - page: int = Query(default=0, description="The page to get"), - per_page: int = Query(default=10, description="The number of models per page"), - order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), -) -> PaginatedResults[ModelSummary]: - """Gets a page of model summary data.""" - return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by) - - -@model_records_router.get( - "/meta/i/{key}", - operation_id="get_model_metadata", - responses={ - 200: {"description": "Success"}, - 400: {"description": "Bad request"}, - 404: {"description": "No metadata available"}, - }, -) -async def get_model_metadata( - key: str = Path(description="Key of the model repo metadata to fetch."), -) -> Optional[AnyModelRepoMetadata]: - """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_records - result = record_store.get_metadata(key) - if not result: - raise HTTPException(status_code=404, detail="No metadata for a model with this key") - return result - - -@model_records_router.get( - "/tags", - operation_id="list_tags", -) -async def list_tags() -> Set[str]: - """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_records - return record_store.list_tags() - - -@model_records_router.get( - "/tags/search", - operation_id="search_by_metadata_tags", -) -async def search_by_metadata_tags( - tags: Set[str] = Query(default=None, description="Tags to search for"), -) -> ModelsList: - """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records - results = record_store.search_by_metadata_tag(tags) - return ModelsList(models=results) - - -@model_records_router.patch( - "/i/{key}", - operation_id="update_model_record", - responses={ - 200: {"description": "The model was updated successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - 409: {"description": "There is already a model corresponding to the new name"}, - }, - status_code=200, - response_model=AnyModelConfig, -) -async def update_model_record( - key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[AnyModelConfig, 
Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: - """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" - logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records - try: - model_response = record_store.update_model(key, config=info) - logger.info(f"Updated model: {key}") - except UnknownModelException as e: - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - return model_response - - -@model_records_router.delete( - "/i/{key}", - operation_id="del_model_record", - responses={ - 204: {"description": "Model deleted successfully"}, - 404: {"description": "Model not found"}, - }, - status_code=204, -) -async def del_model_record( - key: str = Path(description="Unique key of model to remove from model registry."), -) -> Response: - """ - Delete model record from database. - - The configuration record will be removed. The corresponding weights files will be - deleted as well if they reside within the InvokeAI "models" directory. - """ - logger = ApiDependencies.invoker.services.logger - - try: - installer = ApiDependencies.invoker.services.model_install - installer.delete(key) - logger.info(f"Deleted model: {key}") - return Response(status_code=204) - except UnknownModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - - -@model_records_router.post( - "/i/", - operation_id="add_model_record", - responses={ - 201: {"description": "The model added successfully"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - 415: {"description": "Unrecognized file/folder format"}, - }, - status_code=201, -) -async def add_model_record( - config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: - """Add a model using the configuration information appropriate for its type.""" - logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records - if config.key == "": - config.key = sha1(randbytes(100)).hexdigest() - logger.info(f"Created model {config.key} for {config.name}") - try: - record_store.add_model(config.key, config) - except DuplicateModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - - # now fetch it out - return record_store.get_model(config.key) - - -@model_records_router.post( - "/import", - operation_id="import_model_record", - responses={ - 201: {"description": "The model imported successfully"}, - 415: {"description": "Unrecognized file/folder format"}, - 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, -) -async def import_model( - source: ModelSource, - config: Optional[Dict[str, Any]] = Body( - description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", - default=None, - ), -) -> ModelInstallJob: - """Add a model using its local path, repo_id, or remote URL. - - Models will be downloaded, probed, configured and installed in a - series of background threads. 
The return object has `status` attribute - that can be used to monitor progress. - - The source object is a discriminated Union of LocalModelSource, - HFModelSource and URLModelSource. Set the "type" field to the - appropriate value: - - * To install a local path using LocalModelSource, pass a source of form: - `{ - "type": "local", - "path": "/path/to/model", - "inplace": false - }` - The "inplace" flag, if true, will register the model in place in its - current filesystem location. Otherwise, the model will be copied - into the InvokeAI models directory. - - * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - `{ - "type": "hf", - "repo_id": "stabilityai/stable-diffusion-2.0", - "variant": "fp16", - "subfolder": "vae", - "access_token": "f5820a918aaf01" - }` - The `variant`, `subfolder` and `access_token` fields are optional. - - * To install a remote model using an arbitrary URL, pass: - `{ - "type": "url", - "url": "http://www.civitai.com/models/123456", - "access_token": "f5820a918aaf01" - }` - The `access_token` field is optonal - - The model's configuration record will be probed and filled in - automatically. To override the default guesses, pass "metadata" - with a Dict containing the attributes you wish to override. - - Installation occurs in the background. Either use list_model_install_jobs() - to poll for completion, or listen on the event bus for the following events: - - "model_install_running" - "model_install_completed" - "model_install_error" - - On successful completion, the event's payload will contain the field "key" - containing the installed ID of the model. On an error, the event's payload - will contain the fields "error_type" and "error" describing the nature of the - error and its traceback, respectively. 
- - """ - logger = ApiDependencies.invoker.services.logger - - try: - installer = ApiDependencies.invoker.services.model_install - result: ModelInstallJob = installer.import_model( - source=source, - config=config, - ) - logger.info(f"Started installation of {source}") - except UnknownModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=424, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - return result - - -@model_records_router.get( - "/import", - operation_id="list_model_install_jobs", -) -async def list_model_install_jobs() -> List[ModelInstallJob]: - """Return list of model install jobs.""" - jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs() - return jobs - - -@model_records_router.get( - "/import/{id}", - operation_id="get_model_install_job", - responses={ - 200: {"description": "Success"}, - 404: {"description": "No such job"}, - }, -) -async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: - """Return model install job corresponding to the given source.""" - try: - return ApiDependencies.invoker.services.model_install.get_job_by_id(id) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@model_records_router.delete( - "/import/{id}", - operation_id="cancel_model_install_job", - responses={ - 201: {"description": "The job was cancelled successfully"}, - 415: {"description": "No such job"}, - }, - status_code=201, -) -async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: - """Cancel the model install job(s) corresponding to the given job ID.""" - installer = ApiDependencies.invoker.services.model_install - try: - job = installer.get_job_by_id(id) - except ValueError as e: - raise HTTPException(status_code=415, detail=str(e)) - installer.cancel_job(job) - - -@model_records_router.patch( - "/import", - operation_id="prune_model_install_jobs", - responses={ - 204: {"description": "All completed and errored jobs have been pruned"}, - 400: {"description": "Bad request"}, - }, -) -async def prune_model_install_jobs() -> Response: - """Prune all completed and errored jobs from the install job list.""" - ApiDependencies.invoker.services.model_install.prune_jobs() - return Response(status_code=204) - - -@model_records_router.patch( - "/sync", - operation_id="sync_models_to_config", - responses={ - 204: {"description": "Model config record database resynced with files on disk"}, - 400: {"description": "Bad request"}, - }, -) -async def sync_models_to_config() -> Response: - """ - Traverse the models and autoimport directories. - - Model files without a corresponding - record in the database are added. Orphan records without a models file are deleted. 
- """ - ApiDependencies.invoker.services.model_install.sync_to_config() - return Response(status_code=204) - - -@model_records_router.put( - "/merge", - operation_id="merge", -) -async def merge( - keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), - merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), - alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), - force: bool = Body( - description="Force merging of models created with different versions of diffusers", - default=False, - ), - interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), - merge_dest_directory: Optional[str] = Body( - description="Save the merged model to the designated directory (with 'merged_model_name' appended)", - default=None, - ), -) -> AnyModelConfig: - """ - Merge diffusers models. - - keys: List of 2-3 model keys to merge together. All models must use the same base type. - merged_model_name: Name for the merged model [Concat model names] - alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] - force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] - interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] - merge_dest_directory: Specify a directory to store the merged model in [models directory] - """ - print(f"here i am, keys={keys}") - logger = ApiDependencies.invoker.services.logger - try: - logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") - dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - installer = ApiDependencies.invoker.services.model_install - merger = ModelMerger(installer) - model_names = [installer.record_store.get_model(x).name for x in keys] - response = merger.merge_diffusion_models_and_save( - model_keys=keys, - merged_model_name=merged_model_name or "+".join(model_names), - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=dest, - ) - except UnknownModelException: - raise HTTPException( - status_code=404, - detail=f"One or more of the models '{keys}' not found", - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8f83820cf89..0aa7aa0ecba 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -8,8 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from starlette.exceptions import HTTPException -from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import MergeInterpolationMethod +from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType from invokeai.backend.model_management.models import ( OPENAPI_MODEL_CONFIGS, InvalidModelException, diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6294083d0e1..1831b54c13c 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -6,6 +6,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.version.invokeai_version import __version__ +from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from .services.config import InvokeAIAppConfig app_config = 
InvokeAIAppConfig.get_config() @@ -47,8 +48,7 @@ boards, download_queue, images, - model_records, - models, + model_manager_v2, session_queue, sessions, utilities, @@ -57,8 +57,6 @@ from .api.sockets import SocketIO from .invocations.baseinvocation import ( BaseInvocation, - InputFieldJSONSchemaExtra, - OutputFieldJSONSchemaExtra, UIConfigBase, ) @@ -115,8 +113,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(models.models_router, prefix="/api") -app.include_router(model_records.model_records_router, prefix="/api") +app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index d9e0c7ba0d2..3243714937f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -12,13 +12,16 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast import semver -from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model -from pydantic.fields import FieldInfo, _Unset +from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from invokeai.app.invocations.fields import ( + FieldKind, + Input, +) from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string from invokeai.backend.util.logging import InvokeAILogger @@ -52,393 +55,6 @@ class Classification(str, Enum, metaclass=MetaEnum): Prototype = "prototype" -class Input(str, Enum, metaclass=MetaEnum): - """ - The type of input a field accepts. - - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ - are instantiated. - - `Input.Connection`: The field must have its value provided by a connection. - - `Input.Any`: The field may have its value provided either directly or by a connection. - """ - - Connection = "connection" - Direct = "direct" - Any = "any" - - -class FieldKind(str, Enum, metaclass=MetaEnum): - """ - The kind of field. - - `Input`: An input field on a node. - - `Output`: An output field on a node. - - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is - one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name - "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, - allowing "metadata" for that field. - - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, - but which are used to store information about the node. For example, the `id` and `type` fields are node - attributes. 
- - The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app - startup, and when generating the OpenAPI schema for the workflow editor. - """ - - Input = "input" - Output = "output" - Internal = "internal" - NodeAttribute = "node_attribute" - - -class UIType(str, Enum, metaclass=MetaEnum): - """ - Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. - - - Model Fields - The most common node-author-facing use will be for model fields. Internally, there is no difference - between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the - base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that - the field is an SDXL main model field. - - - Any Field - We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to - indicate that the field accepts any type. Use with caution. This cannot be used on outputs. - - - Scheduler Field - Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. - - - Internal Fields - Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate - handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These - should not be used by node authors. - - - DEPRECATED Fields - These types are deprecated and should not be used by node authors. A warning will be logged if one is - used, and the type will be ignored. They are included here for backwards compatibility. - """ - - # region Model Field Types - SDXLMainModel = "SDXLMainModelField" - SDXLRefinerModel = "SDXLRefinerModelField" - ONNXModel = "ONNXModelField" - VaeModel = "VAEModelField" - LoRAModel = "LoRAModelField" - ControlNetModel = "ControlNetModelField" - IPAdapterModel = "IPAdapterModelField" - # endregion - - # region Misc Field Types - Scheduler = "SchedulerField" - Any = "AnyField" - # endregion - - # region Internal Field Types - _Collection = "CollectionField" - _CollectionItem = "CollectionItemField" - # endregion - - # region DEPRECATED - Boolean = "DEPRECATED_Boolean" - Color = "DEPRECATED_Color" - Conditioning = "DEPRECATED_Conditioning" - Control = "DEPRECATED_Control" - Float = "DEPRECATED_Float" - Image = "DEPRECATED_Image" - Integer = "DEPRECATED_Integer" - Latents = "DEPRECATED_Latents" - String = "DEPRECATED_String" - BooleanCollection = "DEPRECATED_BooleanCollection" - ColorCollection = "DEPRECATED_ColorCollection" - ConditioningCollection = "DEPRECATED_ConditioningCollection" - ControlCollection = "DEPRECATED_ControlCollection" - FloatCollection = "DEPRECATED_FloatCollection" - ImageCollection = "DEPRECATED_ImageCollection" - IntegerCollection = "DEPRECATED_IntegerCollection" - LatentsCollection = "DEPRECATED_LatentsCollection" - StringCollection = "DEPRECATED_StringCollection" - BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" - ColorPolymorphic = "DEPRECATED_ColorPolymorphic" - ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" - ControlPolymorphic = "DEPRECATED_ControlPolymorphic" - FloatPolymorphic = "DEPRECATED_FloatPolymorphic" - ImagePolymorphic = "DEPRECATED_ImagePolymorphic" - IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" - LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" - StringPolymorphic = "DEPRECATED_StringPolymorphic" - MainModel = "DEPRECATED_MainModel" - UNet = "DEPRECATED_UNet" - Vae = "DEPRECATED_Vae" - CLIP = 
"DEPRECATED_CLIP" - Collection = "DEPRECATED_Collection" - CollectionItem = "DEPRECATED_CollectionItem" - Enum = "DEPRECATED_Enum" - WorkflowField = "DEPRECATED_WorkflowField" - IsIntermediate = "DEPRECATED_IsIntermediate" - BoardField = "DEPRECATED_BoardField" - MetadataItem = "DEPRECATED_MetadataItem" - MetadataItemCollection = "DEPRECATED_MetadataItemCollection" - MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" - MetadataDict = "DEPRECATED_MetadataDict" - # endregion - - -class UIComponent(str, Enum, metaclass=MetaEnum): - """ - The type of UI component to use for a field, used to override the default components, which are - inferred from the field type. - """ - - None_ = "none" - Textarea = "textarea" - Slider = "slider" - - -class InputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, - and by the workflow editor during schema parsing and UI rendering. - """ - - input: Input - orig_required: bool - field_kind: FieldKind - default: Optional[Any] = None - orig_default: Optional[Any] = None - ui_hidden: bool = False - ui_type: Optional[UIType] = None - ui_component: Optional[UIComponent] = None - ui_order: Optional[int] = None - ui_choice_labels: Optional[dict[str, str]] = None - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -class OutputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor - during schema parsing and UI rendering. - """ - - field_kind: FieldKind - ui_hidden: bool - ui_type: Optional[UIType] - ui_order: Optional[int] - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -def InputField( - # copied from pydantic's Field - # TODO: Can we support default_factory? - default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - input: Input = Input.Any, - ui_type: Optional[UIType] = None, - ui_component: Optional[UIComponent] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, - ui_choice_labels: Optional[dict[str, str]] = None, -) -> Any: - """ - Creates an input field for an invocation. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param Input input: [Input.Any] The kind of input this field requires. \ - `Input.Direct` means a value must be provided on instantiation. \ - `Input.Connection` means the value must be provided by a connection. \ - `Input.Any` means either will do. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. 
\ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ - The UI will always render a suitable component, but sometimes you want something different than the default. \ - For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ - For this case, you could provide `UIComponent.Textarea`. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - - :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. - """ - - json_schema_extra_ = InputFieldJSONSchemaExtra( - input=input, - ui_type=ui_type, - ui_component=ui_component, - ui_hidden=ui_hidden, - ui_order=ui_order, - ui_choice_labels=ui_choice_labels, - field_kind=FieldKind.Input, - orig_required=True, - ) - - """ - There is a conflict between the typing of invocation definitions and the typing of an invocation's - `invoke()` function. - - On instantiation of a node, the invocation definition is used to create the python class. At this time, - any number of fields may be optional, because they may be provided by connections. - - On calling of `invoke()`, however, those fields may be required. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - connection from an ancestor node, which outputs an image. - - This means we want to type the `image` field as optional for the node class definition, but required - for the `invoke()` function. - - If we use `typing.Optional` in the node class definition, the field will be typed as optional in the - `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or - any static type analysis tools will complain. - - To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, - but secretly make them optional in `InputField()`. We also store the original required bool and/or default - value. When we call `invoke()`, we use this stored information to do an additional check on the class. 
- """ - - if default_factory is not _Unset and default_factory is not None: - default = default_factory() - logger.warn('"default_factory" is not supported, calling it now to set "default"') - - # These are the args we may wish pass to the pydantic `Field()` function - field_args = { - "default": default, - "title": title, - "description": description, - "pattern": pattern, - "strict": strict, - "gt": gt, - "ge": ge, - "lt": lt, - "le": le, - "multiple_of": multiple_of, - "allow_inf_nan": allow_inf_nan, - "max_digits": max_digits, - "decimal_places": decimal_places, - "min_length": min_length, - "max_length": max_length, - } - - # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected - provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - - # Because we are manually making fields optional, we need to store the original required bool for reference later - json_schema_extra_.orig_required = default is PydanticUndefined - - # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if input is Input.Any or input is Input.Connection: - default_ = None if default is PydanticUndefined else default - provided_args.update({"default": default_}) - if default is not PydanticUndefined: - # Before invoking, we'll check for the original default value and set it on the field if the field has no value - json_schema_extra_.default = default - json_schema_extra_.orig_default = default - elif default is not PydanticUndefined: - default_ = default - provided_args.update({"default": default_}) - json_schema_extra_.orig_default = default_ - - return Field( - **provided_args, - json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), - ) - - -def OutputField( - # copied from pydantic's Field - default: Any = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - ui_type: Optional[UIType] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, -) -> Any: - """ - Creates an output field for an invocation output. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. 
\ - """ - return Field( - default=default, - title=title, - description=description, - pattern=pattern, - strict=strict, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, - max_digits=max_digits, - decimal_places=decimal_places, - min_length=min_length, - max_length=max_length, - json_schema_extra=OutputFieldJSONSchemaExtra( - ui_type=ui_type, - ui_hidden=ui_hidden, - ui_order=ui_order, - field_kind=FieldKind.Output, - ).model_dump(exclude_none=True), - ) - - class UIConfigBase(BaseModel): """ Provides additional node configuration to the UI. @@ -460,33 +76,6 @@ class UIConfigBase(BaseModel): ) -class InvocationContext: - """Initialized and provided to on execution of invocations.""" - - services: InvocationServices - graph_execution_state_id: str - queue_id: str - queue_item_id: int - queue_batch_id: str - workflow: Optional[WorkflowWithoutID] - - def __init__( - self, - services: InvocationServices, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - workflow: Optional[WorkflowWithoutID], - ): - self.services = services - self.graph_execution_state_id = graph_execution_state_id - self.queue_id = queue_id - self.queue_item_id = queue_item_id - self.queue_batch_id = queue_batch_id - self.workflow = workflow - - class BaseInvocationOutput(BaseModel): """ Base class for all invocation outputs. @@ -632,7 +221,7 @@ def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput: """ Internal invoke method, calls `invoke()` after some prep. Handles optional fields that are required to call `invoke()` and invocation cache. 
@@ -657,23 +246,23 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: raise MissingInputException(self.model_fields["type"].default, field_name) # skip node cache codepath if it's disabled - if context.services.configuration.node_cache_size == 0: + if services.configuration.node_cache_size == 0: return self.invoke(context) output: BaseInvocationOutput if self.use_cache: - key = context.services.invocation_cache.create_key(self) - cached_value = context.services.invocation_cache.get(key) + key = services.invocation_cache.create_key(self) + cached_value = services.invocation_cache.get(key) if cached_value is None: - context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') output = self.invoke(context) - context.services.invocation_cache.save(key, output) + services.invocation_cache.save(key, output) return output else: - context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') return cached_value else: - context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') + services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) id: str = Field( @@ -714,9 +303,7 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: "workflow", } -RESERVED_INPUT_FIELD_NAMES = { - "metadata", -} +RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"} RESERVED_OUTPUT_FIELD_NAMES = {"type"} @@ -926,37 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper - - -class MetadataField(RootModel): - """ - Pydantic model for metadata with custom root of type dict[str, Any]. - Metadata is stored without a strict schema. - """ - - root: dict[str, Any] = Field(description="The metadata") - - -MetadataFieldValidator = TypeAdapter(MetadataField) - - -class WithMetadata(BaseModel): - metadata: Optional[MetadataField] = Field( - default=None, - description=FieldDescriptions.metadata, - json_schema_extra=InputFieldJSONSchemaExtra( - field_kind=FieldKind.Internal, - input=Input.Connection, - orig_required=False, - ).model_dump(exclude_none=True), - ) - - -class WithWorkflow: - workflow = None - - def __init_subclass__(cls) -> None: - logger.warn( - f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." 
- ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 4c7b6f94cd4..e02291980f9 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,9 +5,11 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 49c62cff564..5159d5b89c5 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,40 +1,43 @@ -from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from transformers import CLIPTokenizer -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput -from invokeai.app.shared.fields import FieldDescriptions +import invokeai.backend.util.logging as logger +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, +) +from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.services.model_records import UnknownModelException +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt +from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.embeddings.model_patcher import ModelPatcher +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager import ModelType from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, + ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import ModelNotFoundException, ModelType -from ...backend.util.devices import torch_dtype -from ..util.ti_utils import extract_ti_triggers_from_prompt from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) from .model import ClipField - -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] - # unconditioned: Optional[torch.Tensor] +# unconditioned: Optional[torch.Tensor] # class ConditioningAlgo(str, Enum): @@ -48,7 +51,7 @@ class ConditioningFieldData: title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -66,49 +69,34 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - 
context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) - yield (lora_info.context.model, lora.weight) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) + assert isinstance(lora_info.model, LoRAModelRaw) + yield (lora_info.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model, - ) - ) - except ModelNotFoundException: + loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model + assert isinstance(loaded_model, TextualInversionModelRaw) + ti_list.append((name, loaded_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -116,7 +104,7 @@ def _lora_loader(): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. 
- ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -128,7 +116,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(self.prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: log_tokenization_for_conjunction(conjunction, tokenizer) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -149,17 +137,14 @@ def _lora_loader(): ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) class SDXLPromptInvocationBase: + """Prompt processor for SDXL models.""" + def run_clip_compel( self, context: InvocationContext, @@ -168,26 +153,21 @@ def run_clip_compel( get_pooled: bool, lora_prefix: str, zero_on_empty: bool, - ): - tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.model_dump(), - context=context, - ) + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: - cpu_text_encoder = text_encoder_info.context.model + cpu_text_encoder = text_encoder_info.model + assert isinstance(cpu_text_encoder, torch.nn.Module) c = torch.zeros( ( 1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size, ), - dtype=text_encoder_info.context.cache.precision, + dtype=cpu_text_encoder.dtype, ) if get_pooled: c_pooled = torch.zeros( @@ -198,40 +178,36 @@ def run_clip_compel( c_pooled = None return c, c_pooled, None - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) - yield (lora_info.context.model, lora.weight) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) + lora_model = lora_info.model + assert isinstance(lora_model, LoRAModelRaw) + yield (lora_model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model, - ) - ) - except ModelNotFoundException: + ti_model = context.models.load_by_attrs( + model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion + ).model + assert isinstance(ti_model, TextualInversionModelRaw) + 
ti_list.append((name, ti_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') + logger.warning(f'trigger: "{trigger}" not found') + except ValueError: + logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -239,7 +215,7 @@ def _lora_loader(): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -253,7 +229,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: # TODO: better logging for and syntax log_tokenization_for_conjunction(conjunction, tokenizer) @@ -286,7 +262,7 @@ def _lora_loader(): title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -357,6 +333,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: dim=1, ) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -368,14 +345,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation( @@ -383,7 +355,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: title="SDXL Refiner Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -410,6 +382,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -421,14 +394,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation_output("clip_skip_output") @@ -459,9 +427,9 @@ def invoke(self, context: 
InvocationContext) -> ClipSkipInvocationOutput: def get_max_token_count( - tokenizer, + tokenizer: CLIPTokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], - truncate_if_too_long=False, + truncate_if_too_long: bool = False, ) -> int: if type(prompt) is Blend: blend: Blend = prompt @@ -473,7 +441,9 @@ def get_max_token_count( return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)) -def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]: +def get_tokens_for_prompt_object( + tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError("Blend is not supported here - you need to get tokens for each of its .children") @@ -486,24 +456,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun for x in parsed_prompt.children ] text = " ".join(text_fragments) - tokens = tokenizer.tokenize(text) + tokens: List[str] = tokenizer.tokenize(text) if truncate_if_too_long: max_tokens_length = tokenizer.model_max_length - 2 # typically 75 tokens = tokens[0:max_tokens_length] return tokens -def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None): +def log_tokenization_for_conjunction( + c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: + assert display_label_prefix is not None this_display_label_prefix = display_label_prefix log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix) -def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): +def log_tokenization_for_prompt_object( + p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" if type(p) is Blend: blend: Blend = p @@ -543,7 +518,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text: str, + tokenizer: CLIPTokenizer, + display_label: Optional[str] = None, + truncate_if_too_long: Optional[bool] = False, +) -> None: """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py new file mode 100644 index 00000000000..795e7a3b604 --- /dev/null +++ b/invokeai/app/invocations/constants.py @@ -0,0 +1,14 @@ +from typing import Literal + +from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP + +LATENT_SCALE_FACTOR = 8 +""" +HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +factor is hard-coded to a literal '8' rather than using this constant. +The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. 
+""" + +SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +"""A literal type representing the valid scheduler names.""" diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1f9342985a0..8542134fff0 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -23,24 +23,26 @@ ) from controlnet_aux.util import HWC3, ade_palette from PIL import Image -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator -from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + OutputField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) @@ -57,10 +59,7 @@ class ControlNetModelField(BaseModel): """ControlNet model field""" - model_name: str = Field(description="Name of the ControlNet model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model config record key for the ControlNet model") class ControlField(BaseModel): @@ -140,7 +139,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput: # This invocation exists for other invocations to subclass it - do not register with @invocation! -class ImageProcessorInvocation(BaseInvocation, WithMetadata): +class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): """Base class for invocations that preprocess images for ControlNet""" image: ImageField = InputField(description="The image to process") @@ -150,22 +149,13 @@ def run_processor(self, image: Image.Image) -> Image.Image: return image def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image(self.image.image_name) + raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? 
processed_image = self.run_processor(raw_image) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery - image_dto = context.services.images.create( - image=processed_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.CONTROL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=processed_image) """Builds an ImageOutput and its ImageField""" processed_image_field = ImageField(image_name=image_dto.image_name) @@ -184,7 +174,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Canny Processor", tags=["controlnet", "canny"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -207,7 +197,7 @@ def run_processor(self, image): title="HED (softedge) Processor", tags=["controlnet", "hed", "softedge"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -236,7 +226,7 @@ def run_processor(self, image): title="Lineart Processor", tags=["controlnet", "lineart"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -258,7 +248,7 @@ def run_processor(self, image): title="Lineart Anime Processor", tags=["controlnet", "lineart", "anime"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -281,7 +271,7 @@ def run_processor(self, image): title="Midas Depth Processor", tags=["controlnet", "midas"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -308,7 +298,7 @@ def run_processor(self, image): title="Normal BAE Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -325,7 +315,7 @@ def run_processor(self, image): @invocation( - "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0" + "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1" ) class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -348,7 +338,7 @@ def run_processor(self, image): @invocation( - "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0" + "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1" ) class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -375,7 +365,7 @@ def run_processor(self, image): title="Content Shuffle Processor", tags=["controlnet", "contentshuffle"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle 
processing to image""" @@ -405,7 +395,7 @@ def run_processor(self, image): title="Zoe (Depth) Processor", tags=["controlnet", "zoe", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -421,7 +411,7 @@ def run_processor(self, image): title="Mediapipe Face Processor", tags=["controlnet", "mediapipe", "face"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -444,7 +434,7 @@ def run_processor(self, image): title="Leres (Depth) Processor", tags=["controlnet", "leres", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -473,7 +463,7 @@ def run_processor(self, image): title="Tile Resample Processor", tags=["controlnet", "tile"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -513,7 +503,7 @@ def run_processor(self, img): title="Segment Anything Processor", tags=["controlnet", "segmentanything"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" @@ -555,7 +545,7 @@ def show_anns(self, anns: List[Dict]): title="Color Map Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ColorMapImageProcessorInvocation(ImageProcessorInvocation): """Generates a color map from the provided image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index cb6828d21ac..8174f19b64c 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,22 +5,24 @@ import numpy from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, WithBoard, WithMetadata -@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") -class CvInpaintInvocation(BaseInvocation, WithMetadata): +@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") +class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = context.services.images.get_pil_image(self.mask.image_name) + image = context.images.get_pil(self.image.image_name) + mask = context.images.get_pil(self.mask.image_name) # Convert to cv image/mask # TODO: consider making these utility functions @@ -34,18 
+36,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # TODO: consider making a utility function image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - image_dto = context.services.images.create( - image=image_inpainted, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=image_inpainted) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index e0c89b4de5a..fed2ed5e4f2 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,15 +13,13 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InputField, - InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext @invocation_output("face_mask_output") @@ -306,37 +304,37 @@ def extract_face( # Adjust the crop boundaries to stay within the original image's dimensions if x_min < 0: - context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.") + context.logger.warning("FaceTools --> -X-axis padding reached image edge.") x_max -= x_min x_min = 0 elif x_max > mask.width: - context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.") + context.logger.warning("FaceTools --> +X-axis padding reached image edge.") x_min -= x_max - mask.width x_max = mask.width if y_min < 0: - context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> +Y-axis padding reached image edge.") y_max -= y_min y_min = 0 elif y_max > mask.height: - context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> -Y-axis padding reached image edge.") y_min -= y_max - mask.height y_max = mask.height # Ensure the crop is square and adjust the boundaries if needed if x_max - x_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") diff = crop_size - (x_max - x_min) x_min -= diff // 2 x_max += diff - diff // 2 if y_max - y_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") diff = crop_size - (y_max - y_min) y_min -= diff // 2 y_max += diff - diff // 2 - context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") + 
context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") # Crop the output image to the specified size with the center of the face mesh as the center. mask = mask.crop((x_min, y_min, x_max, y_max)) @@ -368,7 +366,7 @@ def get_faces_list( # Generate the face box mask and get the center of the face. if not should_chunk: - context.services.logger.info("FaceTools --> Attempting full image face detection.") + context.logger.info("FaceTools --> Attempting full image face detection.") result = generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -380,7 +378,7 @@ def get_faces_list( draw_mesh=draw_mesh, ) if should_chunk or len(result) == 0: - context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") + context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") width, height = image.size image_chunks = [] x_offsets = [] @@ -399,7 +397,7 @@ def get_faces_list( x_offsets.append(x) y_offsets.append(0) fx += increment - context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}") + context.logger.info(f"FaceTools --> Chunk starting at x = {x}") elif height > width: # Portrait - slice the image vertically fy = 0.0 @@ -411,10 +409,10 @@ def get_faces_list( x_offsets.append(0) y_offsets.append(y) fy += increment - context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}") + context.logger.info(f"FaceTools --> Chunk starting at y = {y}") for idx in range(len(image_chunks)): - context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") + context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") result = result + generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -428,7 +426,7 @@ def get_faces_list( if len(result) == 0: # Give up - context.services.logger.warning( + context.logger.warning( "FaceTools --> No face detected in chunked input image. Passing through original image." ) @@ -437,7 +435,7 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0") +@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1") class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -470,11 +468,11 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr ) if len(all_faces) == 0: - context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.") + context.logger.warning("FaceOff --> No faces detected. Passing through original image.") return None if self.face_id > len(all_faces) - 1: - context.services.logger.warning( + context.logger.warning( f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image." 
) return None @@ -486,7 +484,7 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr return face_data def invoke(self, context: InvocationContext) -> FaceOffOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) if result is None: @@ -500,24 +498,9 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: x = result["x_min"] y = result["y_min"] - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - mask_dto = context.services.images.create( - image=result_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK) output = FaceOffOutput( image=ImageField(image_name=image_dto.image_name), @@ -531,7 +514,7 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0") +@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1") class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -580,7 +563,7 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu if len(intersected_face_ids) == 0: id_range_str = ",".join([str(id) for id in id_range]) - context.services.logger.warning( + context.logger.warning( f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image." 
) return FaceMaskResult( @@ -616,27 +599,12 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu ) def invoke(self, context: InvocationContext) -> FaceMaskOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) - image_dto = context.services.images.create( - image=result["image"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result["image"]) - mask_dto = context.services.images.create( - image=result["mask"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK) output = FaceMaskOutput( image=ImageField(image_name=image_dto.image_name), @@ -649,9 +617,9 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0" + "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) -class FaceIdentifierInvocation(BaseInvocation, WithMetadata): +class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" image: ImageField = InputField(description="Image to face detect") @@ -705,21 +673,9 @@ def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageT return image def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py new file mode 100644 index 00000000000..40d403c03d9 --- /dev/null +++ b/invokeai/app/invocations/fields.py @@ -0,0 +1,565 @@ +from enum import Enum +from typing import Any, Callable, Optional, Tuple + +from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter +from pydantic.fields import _Unset +from pydantic_core import PydanticUndefined + +from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger() + + +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + + - Model Fields + The most common node-author-facing use will be for model fields. 
Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. + """ + + # region Model Field Types + SDXLMainModel = "SDXLMainModelField" + SDXLRefinerModel = "SDXLRefinerModelField" + ONNXModel = "ONNXModelField" + VaeModel = "VAEModelField" + LoRAModel = "LoRAModelField" + ControlNetModel = "ControlNetModelField" + IPAdapterModel = "IPAdapterModelField" + # endregion + + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" + # endregion + + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" + # endregion + + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" + + +class UIComponent(str, Enum, metaclass=MetaEnum): + """ + 
The type of UI component to use for a field, used to override the default components, which are + inferred from the field type. + """ + + None_ = "none" + Textarea = "textarea" + Slider = "slider" + + +class FieldDescriptions: + denoising_start = "When to start denoising, expressed as a percentage of total steps" + denoising_end = "When to stop denoising, expressed as a percentage of total steps" + cfg_scale = "Classifier-Free Guidance scale" + cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" + scheduler = "Scheduler to use during inference" + positive_cond = "Positive conditioning tensor" + negative_cond = "Negative conditioning tensor" + noise = "Noise tensor" + clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" + unet = "UNet (scheduler, LoRAs)" + vae = "VAE" + cond = "Conditioning tensor" + controlnet_model = "ControlNet model to load" + vae_model = "VAE model to load" + lora_model = "LoRA model to load" + main_model = "Main model (UNet, VAE, CLIP) to load" + sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" + sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load" + onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + lora_weight = "The weight at which the LoRA is applied to each model" + compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" + raw_prompt = "Raw prompt text (no parsing)" + sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" + skipped_layers = "Number of layers to skip in text encoder" + seed = "Seed for random number generation" + steps = "Number of steps to run" + width = "Width of output (px)" + height = "Height of output (px)" + control = "ControlNet(s) to apply" + ip_adapter = "IP-Adapter to apply" + t2i_adapter = "T2I-Adapter(s) to apply" + denoised_latents = "Denoised latents tensor" + latents = "Latents tensor" + strength = "Strength of denoising (proportional to steps)" + metadata = "Optional metadata to be saved with the image" + metadata_collection = "Collection of Metadata" + metadata_item_polymorphic = "A single metadata item or collection of metadata items" + metadata_item_label = "Label for this metadata item" + metadata_item_value = "The value for this metadata item (may be any type)" + workflow = "Optional workflow to be saved with the image" + interp_mode = "Interpolation mode" + torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" + fp32 = "Whether or not to use full float32 precision" + precision = "Precision to use" + tiled = "Processing using overlapping tiles (reduce memory consumption)" + detect_res = "Pixel resolution for detection" + image_res = "Pixel resolution for output image" + safe_mode = "Whether or not to use safe mode" + scribble_mode = "Whether or not to use scribble mode" + scale_factor = "The factor by which to scale" + blend_alpha = ( + "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." + ) + num_1 = "The first number" + num_2 = "The second number" + mask = "The mask to use for the operation" + board = "The board to save the image to" + image = "The image to process" + tile_size = "Tile size" + inclusive_low = "The inclusive low value" + exclusive_high = "The exclusive high value" + decimal_places = "The number of decimal places to round to" + freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features.
This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' + freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' + freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." + freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." + + +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + # endregion + + +class MetadataField(RootModel): + """ + Pydantic model for metadata with custom root of type dict[str, Any]. + Metadata is stored without a strict schema. + """ + + root: dict[str, Any] = Field(description="The metadata") + + +MetadataFieldValidator = TypeAdapter(MetadataField) + + +class Input(str, Enum, metaclass=MetaEnum): + """ + The type of input a field accepts. + - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ + are instantiated. + - `Input.Connection`: The field must have its value provided by a connection. + - `Input.Any`: The field may have its value provided either directly or by a connection. + """ + + Connection = "connection" + Direct = "direct" + Any = "any" + + +class FieldKind(str, Enum, metaclass=MetaEnum): + """ + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. 
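As a rough sketch of where these kinds end up, the hypothetical model below (a stand-in, not a real node) inspects the `field_kind` and `input` entries that `InputField` and `WithMetadata` write into each field's `json_schema_extra`:

```py
from invokeai.app.invocations.fields import InputField, WithMetadata

class ExampleTextNode(WithMetadata):  # hypothetical stand-in for an invocation class
    prompt: str = InputField(description="A prompt")

for name, field in ExampleTextNode.model_fields.items():
    extra = field.json_schema_extra or {}
    print(name, extra.get("field_kind"), extra.get("input"))
# "prompt" is a regular Input field (Input.Any); the inherited "metadata" field is
# Internal and connection-only, matching the reservation described above.
```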
+ """ + + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" + + +class InputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. + """ + + input: Input + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. + """ + + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." + ) + super().__init_subclass__() + + +class WithBoard(BaseModel): + """ + Inherit from this class if your node needs a board input field. + """ + + board: Optional[BoardField] = Field( + default=None, + description=FieldDescriptions.board, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Direct, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class OutputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor + during schema parsing and UI rendering. + """ + + field_kind: FieldKind + ui_hidden: bool + ui_type: Optional[UIType] + ui_order: Optional[int] + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +def InputField( + # copied from pydantic's Field + # TODO: Can we support default_factory? + default: Any = _Unset, + default_factory: Callable[[], Any] | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + input: Input = Input.Any, + ui_type: Optional[UIType] = None, + ui_component: Optional[UIComponent] = None, + ui_hidden: bool = False, + ui_order: Optional[int] = None, + ui_choice_labels: Optional[dict[str, str]] = None, +) -> Any: + """ + Creates an input field for an invocation. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param Input input: [Input.Any] The kind of input this field requires. \ + `Input.Direct` means a value must be provided on instantiation. 
\ + `Input.Connection` means the value must be provided by a connection. \ + `Input.Any` means either will do. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ + The UI will always render a suitable component, but sometimes you want something different than the default. \ + For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ + For this case, you could provide `UIComponent.Textarea`. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. + + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. + """ + + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. + + On calling of `invoke()`, however, those fields may be required. + + For example, consider an ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. 
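A minimal sketch of the pattern this comment describes, using a hypothetical model and image name rather than a real node:

```py
from pydantic import BaseModel

from invokeai.app.invocations.fields import ImageField, InputField

class ExampleNode(BaseModel):  # hypothetical stand-in for a BaseInvocation subclass
    # Annotated as a required ImageField for invoke()-side type checking, but InputField()
    # quietly supplies default=None so the class can be instantiated before a graph
    # connection delivers the value.
    image: ImageField = InputField(description="The input image")

node = ExampleNode()
print(node.image)  # None - the value is expected to arrive via a connection
node = ExampleNode(image=ImageField(image_name="example.png"))
print(node.image.image_name)  # "example.png"
```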
+ """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + logger.warn('"default_factory" is not supported, calling it now to set "default"') + + # These are the args we may wish pass to the pydantic `Field()` function + field_args = { + "default": default, + "title": title, + "description": description, + "pattern": pattern, + "strict": strict, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "multiple_of": multiple_of, + "allow_inf_nan": allow_inf_nan, + "max_digits": max_digits, + "decimal_places": decimal_places, + "min_length": min_length, + "max_length": max_length, + } + + # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected + provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} + + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined + + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: + default_ = None if default is PydanticUndefined else default + provided_args.update({"default": default_}) + if default is not PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: + default_ = default + provided_args.update({"default": default_}) + json_schema_extra_.orig_default = default_ + + return Field( + **provided_args, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), + ) + + +def OutputField( + # copied from pydantic's Field + default: Any = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + ui_type: Optional[UIType] = None, + ui_hidden: bool = False, + ui_order: Optional[int] = None, +) -> Any: + """ + Creates an output field for an invocation output. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. 
\ + """ + return Field( + default=default, + title=title, + description=description, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), + ) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index f729d60cdd5..f5ad5515a68 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,33 +7,36 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import ( + ColorField, + FieldDescriptions, + ImageField, + InputField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - Input, - InputField, - InvocationContext, - WithMetadata, invocation, ) -@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") +@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1") class ShowImageInvocation(BaseInvocation): """Displays a provided image using the OS image viewer, and passes it forward in the pipeline.""" image: ImageField = InputField(description="The image to show") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - if image: - image.show() + image = context.images.get_pil(self.image.image_name) + image.show() # TODO: how to handle failure? 
@@ -49,9 +52,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blank Image", tags=["image"], category="image", - version="1.2.0", + version="1.2.1", ) -class BlankImageInvocation(BaseInvocation, WithMetadata): +class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Creates a blank image and forwards it to the pipeline""" width: int = InputField(default=512, description="The width of the image") @@ -62,22 +65,9 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): def invoke(self, context: InvocationContext) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -85,9 +75,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Crop Image", tags=["image", "crop"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageCropInvocation(BaseInvocation, WithMetadata): +class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard): """Crops an image to a specified box. The box can be outside of the image.""" image: ImageField = InputField(description="The image to crop") @@ -97,27 +87,14 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) image_crop.paste(image, (-self.x, -self.y)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -149,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions new_width = image.width + self.right + self.left @@ -159,20 +136,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste new image onto input image_crop.paste(image, (self.left, self.top)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - 
image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -180,9 +146,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Paste Image", tags=["image", "paste"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImagePasteInvocation(BaseInvocation, WithMetadata): +class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): """Pastes an image into another image.""" base_image: ImageField = InputField(description="The base image") @@ -196,11 +162,11 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): crop: bool = InputField(default=False, description="Crop to base image dimensions") def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get_pil_image(self.base_image.image_name) - image = context.services.images.get_pil_image(self.image.image_name) + base_image = context.images.get_pil(self.base_image.image_name) + image = context.images.get_pil(self.image.image_name) mask = None if self.mask is not None: - mask = context.services.images.get_pil_image(self.mask.image_name) + mask = context.images.get_pil(self.mask.image_name) mask = ImageOps.invert(mask.convert("L")) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -217,22 +183,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: base_w, base_h = base_image.size new_image = new_image.crop((abs(min_x), abs(min_y), abs(min_x) + base_w, abs(min_y) + base_h)) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -240,37 +193,24 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask from Alpha", tags=["image", "mask"], category="image", - version="1.2.0", + version="1.2.1", ) -class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): +class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): """Extracts the alpha channel of an image as a mask.""" image: ImageField = InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] if self.invert: image_mask = ImageOps.invert(image_mask) - image_dto = context.services.images.create( - image=image_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -278,36 +218,23 @@ def 
invoke(self, context: InvocationContext) -> ImageOutput: title="Multiply Images", tags=["image", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") def invoke(self, context: InvocationContext) -> ImageOutput: - image1 = context.services.images.get_pil_image(self.image1.image_name) - image2 = context.services.images.get_pil_image(self.image2.image_name) + image1 = context.images.get_pil(self.image1.image_name) + image2 = context.images.get_pil(self.image2.image_name) multiply_image = ImageChops.multiply(image1, image2) - image_dto = context.services.images.create( - image=multiply_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=multiply_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_CHANNELS = Literal["A", "R", "G", "B"] @@ -318,35 +245,22 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Extract Image Channel", tags=["image", "channel"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageChannelInvocation(BaseInvocation, WithMetadata): +class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard): """Gets a channel from an image.""" image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) - image_dto = context.services.images.create( - image=channel_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=channel_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] @@ -357,35 +271,22 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Convert Image Mode", tags=["image", "convert"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageConvertInvocation(BaseInvocation, WithMetadata): +class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard): """Converts an image to a different mode.""" image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = 
context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) - image_dto = context.services.images.create( - image=converted_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=converted_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -393,9 +294,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur Image", tags=["image", "blur"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageBlurInvocation(BaseInvocation, WithMetadata): +class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Blurs an image""" image: ImageField = InputField(description="The image to blur") @@ -404,29 +305,16 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) blur = ( ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) ) blur_image = image.filter(blur) - image_dto = context.services.images.create( - image=blur_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=blur_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -434,10 +322,10 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Unsharp Mask", tags=["image", "unsharp_mask"], category="image", - version="1.2.0", + version="1.2.1", classification=Classification.Beta, ) -class UnsharpMaskInvocation(BaseInvocation, WithMetadata): +class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an unsharp mask filter to an image""" image: ImageField = InputField(description="The image to use") @@ -451,7 +339,7 @@ def array_from_pil(self, img): return numpy.array(img) / 255 def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) mode = image.mode alpha_channel = image.getchannel("A") if mode == "RGBA" else None @@ -469,16 +357,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if alpha_channel is not None: image.putalpha(alpha_channel) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) return ImageOutput( image=ImageField(image_name=image_dto.image_name), @@ -512,9 +391,9 @@ def invoke(self, context: 
InvocationContext) -> ImageOutput: title="Resize Image", tags=["image", "resize"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageResizeInvocation(BaseInvocation, WithMetadata): +class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard): """Resizes an image to specific dimensions""" image: ImageField = InputField(description="The image to resize") @@ -523,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -532,22 +411,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -555,9 +421,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Scale Image", tags=["image", "scale"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageScaleInvocation(BaseInvocation, WithMetadata): +class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard): """Scales an image by a factor""" image: ImageField = InputField(description="The image to scale") @@ -569,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width * self.scale_factor) @@ -580,22 +446,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -603,9 +456,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Lerp Image", tags=["image", "lerp"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageLerpInvocation(BaseInvocation, WithMetadata): +class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -613,29 +466,16 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata): max: int = InputField(default=255, ge=0, 
le=255, description="The maximum output value") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = image_arr * (self.max - self.min) + self.min lerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=lerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=lerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -643,9 +483,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): +class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Inverse linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -653,29 +493,16 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment] ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=ilerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=ilerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -683,17 +510,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur NSFW Image", tags=["image", "nsfw"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): +class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) - logger = context.services.logger + logger = context.logger logger.debug("Running NSFW checker") if SafetyChecker.has_nsfw_concept(image): logger.info("A potentially NSFW image has been detected. 
Image will be blurred.") @@ -702,22 +529,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: blurry_image.paste(caution, (0, 0), caution) image = blurry_image - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) def _get_caution_img(self) -> Image.Image: import invokeai.app.assets.images as image_assets @@ -731,33 +545,20 @@ def _get_caution_img(self) -> Image.Image: title="Add Invisible Watermark", tags=["image", "watermark"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageWatermarkInvocation(BaseInvocation, WithMetadata): +class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard): """Add an invisible watermark to an image""" image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -765,9 +566,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", - version="1.2.0", + version="1.2.1", ) -class MaskEdgeInvocation(BaseInvocation, WithMetadata): +class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an edge mask to an image""" image: ImageField = InputField(description="The image to apply the mask to") @@ -779,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): ) def invoke(self, context: InvocationContext) -> ImageOutput: - mask = context.services.images.get_pil_image(self.image.image_name).convert("L") + mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))) @@ -794,22 +595,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: new_mask = ImageOps.invert(new_mask) - image_dto = context.services.images.create( - image=new_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - 
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)


 @invocation(
@@ -817,36 +605,23 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
     title="Combine Masks",
     tags=["image", "mask", "multiply"],
     category="image",
-    version="1.2.0",
+    version="1.2.1",
 )
-class MaskCombineInvocation(BaseInvocation, WithMetadata):
+class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""

     mask1: ImageField = InputField(description="The first mask to combine")
     mask2: ImageField = InputField(description="The second image to combine")

     def invoke(self, context: InvocationContext) -> ImageOutput:
-        mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
-        mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L")
+        mask1 = context.images.get_pil(self.mask1.image_name).convert("L")
+        mask2 = context.images.get_pil(self.mask2.image_name).convert("L")

         combined_mask = ImageChops.multiply(mask1, mask2)

-        image_dto = context.services.images.create(
-            image=combined_mask,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=combined_mask, image_category=ImageCategory.MASK)

-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)


 @invocation(
@@ -854,9 +629,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
     title="Color Correct",
     tags=["image", "color"],
     category="image",
-    version="1.2.0",
+    version="1.2.1",
 )
-class ColorCorrectInvocation(BaseInvocation, WithMetadata):
+class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
     """
     Shifts the colors of a target image to match the reference image, optionally
     using a mask to only color-correct certain regions of the target image.
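The image-processing hunks above all migrate to the same node-side pattern: load inputs with `context.images.get_pil()`, save results with `context.images.save()` (passing `image_category` only when the default should be overridden), and build the output with `ImageOutput.build()`. Below is a minimal sketch of a hypothetical node written directly against that API; the node name and class are illustrative, the import paths are taken from hunks elsewhere in this diff, and the `ImageCategory` import path is assumed to be unchanged by this PR.

```py
# Hypothetical example node (not part of this diff), sketching the new context API.
from PIL import ImageChops

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory  # assumed path
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("example_multiply_masks", title="Example Multiply Masks", tags=["image", "mask"], category="image", version="1.0.0")
class ExampleMultiplyMasksInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Multiplies two masks (illustration of the get_pil/save/build pattern)."""

    mask1: ImageField = InputField(description="The first mask")
    mask2: ImageField = InputField(description="The second mask")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load both inputs as PIL images through the context, not context.services.
        mask1 = context.images.get_pil(self.mask1.image_name).convert("L")
        mask2 = context.images.get_pil(self.mask2.image_name).convert("L")

        combined = ImageChops.multiply(mask1, mask2)

        # save() presumably fills in origin, node id, session id, metadata and board
        # from the context and the WithMetadata/WithBoard mixins; only the category
        # is stated explicitly here.
        image_dto = context.images.save(image=combined, image_category=ImageCategory.MASK)

        # build() derives image_name, width and height from the saved DTO.
        return ImageOutput.build(image_dto)
```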
@@ -870,11 +645,11 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): def invoke(self, context: InvocationContext) -> ImageOutput: pil_init_mask = None if self.mask is not None: - pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L") + pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") - init_image = context.services.images.get_pil_image(self.reference.image_name) + init_image = context.images.get_pil(self.reference.image_name) - result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + result = context.images.get_pil(self.image.image_name).convert("RGBA") # if init_image is None or init_mask is None: # return result @@ -948,22 +723,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) - image_dto = context.services.images.create( - image=matched_result, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=matched_result) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -971,16 +733,16 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Adjust Image Hue", tags=["image", "hue"], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): +class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard): """Adjusts the Hue of an image.""" image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space hsv_image = numpy.array(pil_image.convert("HSV")) @@ -994,24 +756,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to PIL format and to original color mode pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) COLOR_CHANNELS = Literal[ @@ -1075,9 +822,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): +class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard): """Add or subtract a value from a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -1085,7 +832,7 @@ class 
ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1104,24 +851,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1146,9 +878,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) -class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Scale a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -1157,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1180,24 +912,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - workflow=context.workflow, - metadata=self.metadata, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1205,64 +922,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Save Image", tags=["primitives", "image"], category="primitives", - version="1.2.0", + version="1.2.1", use_cache=False, ) -class SaveImageInvocation(BaseInvocation, WithMetadata): +class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Saves an image. 
Unlike an image primitive, this invocation stores a copy of the image.""" image: ImageField = InputField(description=FieldDescriptions.image) - board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - board_id=self.board.board_id if self.board else None, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image = context.images.get_pil(self.image.image_name) + image_dto = context.images.save(image=image) -@invocation( - "linear_ui_output", - title="Linear UI Image Output", - tags=["primitives", "image"], - category="primitives", - version="1.0.1", - use_cache=False, -) -class LinearUIOutputInvocation(BaseInvocation, WithMetadata): - """Handles Linear UI Image Outputting tasks.""" - - image: ImageField = InputField(description=FieldDescriptions.image) - board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.services.images.get_dto(self.image.image_name) - - if self.board: - context.services.board_images.add_image_to_board(self.board.board_id, self.image.image_name) - - if image_dto.is_intermediate != self.is_intermediate: - context.services.images.update( - self.image.image_name, changes=ImageRecordChanges(is_intermediate=self.is_intermediate) - ) - - return ImageOutput( - image=ImageField(image_name=self.image.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index c3d00bb1330..53f6f4732fe 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -6,14 +6,16 @@ import numpy as np from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ColorField, ImageField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, WithBoard, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -118,8 +120,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si -@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") -class InfillColorInvocation(BaseInvocation, WithMetadata): +@invocation("infill_rgba", title="Solid Color 
Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with a solid color""" image: ImageField = InputField(description="The image to infill") @@ -129,33 +131,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class InfillTileInvocation(BaseInvocation, WithMetadata): +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") +class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") @@ -168,33 +157,20 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( - "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0" + "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) -class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): +class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the PatchMatch algorithm""" image: ImageField = InputField(description="The image to infill") @@ -202,7 +178,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + image = 
context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -227,77 +203,38 @@ def invoke(self, context: InvocationContext) -> ImageOutput: infilled.paste(image, (0, 0), mask=image.split()[-1]) # image.paste(infilled, (0, 0), mask=image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") -class LaMaInfillInvocation(BaseInvocation, WithMetadata): +@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The image to infill") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") -class CV2InfillInvocation(BaseInvocation, WithMetadata): +@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 6bd28896244..15e254010b5 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -1,38 +1,29 @@ 
-import os
 from builtins import float
 from typing import List, Union

-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
+from typing_extensions import Self

 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
     invocation,
     invocation_output,
 )
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
-from invokeai.app.shared.fields import FieldDescriptions
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType
-from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import BaseModelType, ModelType


+# LS: Consider moving these two classes into model.py
 class IPAdapterModelField(BaseModel):
-    model_name: str = Field(description="Name of the IP-Adapter model")
-    base_model: BaseModelType = Field(description="Base model")
-
-    model_config = ConfigDict(protected_namespaces=())
+    key: str = Field(description="Key to the IP-Adapter model")


 class CLIPVisionModelField(BaseModel):
-    model_name: str = Field(description="Name of the CLIP Vision image encoder model")
-    base_model: BaseModelType = Field(description="Base model (usually 'Any')")
-
-    model_config = ConfigDict(protected_namespaces=())
+    key: str = Field(description="Key to the CLIP Vision image encoder model")


 class IPAdapterField(BaseModel):
@@ -49,12 +40,12 @@ class IPAdapterField(BaseModel):

     @field_validator("weight")
     @classmethod
-    def validate_ip_adapter_weight(cls, v):
+    def validate_ip_adapter_weight(cls, v: float) -> float:
         validate_weights(v)
         return v

     @model_validator(mode="after")
-    def validate_begin_end_step_percent(self):
+    def validate_begin_end_step_percent(self) -> Self:
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self

@@ -65,7 +56,7 @@ class IPAdapterOutput(BaseInvocationOutput):
     ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


-@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""

@@ -87,33 +78,25 @@ class IPAdapterInvocation(BaseInvocation):

     @field_validator("weight")
     @classmethod
-    def validate_ip_adapter_weight(cls, v):
+    def validate_ip_adapter_weight(cls, v: float) -> float:
         validate_weights(v)
         return v

     @model_validator(mode="after")
-    def validate_begin_end_step_percent(self):
+    def validate_begin_end_step_percent(self) -> Self:
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self

     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
- ip_adapter_info = context.services.model_manager.model_info( - self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter - ) - # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model - # directly, and 2) we are reading from disk every time this invocation is called without caching the result. - # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this - # is currently messy due to differences between how the model info is generated when installing a model from - # disk vs. downloading the model. - image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) - ) + ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) + image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_model = CLIPVisionModelField( - model_name=image_encoder_model_name, - base_model=BaseModelType.Any, + image_encoder_models = context.models.search_by_attrs( + model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) + assert len(image_encoder_models) == 1 + image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b77363ceb86..5dd0eb074d5 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,15 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import Any, Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np +import numpy.typing as npt import torch import torchvision.transforms as T from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers.configuration_utils import ConfigMixin from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -18,34 +20,44 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import SchedulerMixin as Scheduler +from PIL import Image from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES +from invokeai.app.invocations.fields import ( + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIType, + WithBoard, + WithMetadata, +) from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( - DenoiseMaskField, DenoiseMaskOutput, - ImageField, ImageOutput, - LatentsField, LatentsOutput, - build_latents_output, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from 
invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_management.models import ModelType, SilenceWarnings +from invokeai.backend.model_manager import BaseModelType, LoadedModel +from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.util.silence_warnings import SilenceWarnings -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import BaseModelType -from ...backend.model_management.seamless import set_seamless -from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, IPAdapterData, @@ -59,16 +71,9 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, - UIType, - WithMetadata, invocation, invocation_output, ) -from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField @@ -77,18 +82,10 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] - -# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to -# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale -# factor is hard-coded to a literal '8' rather than using this constant. -# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. 
-LATENT_SCALE_FACTOR = 8 - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): - scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) @invocation( @@ -101,7 +98,7 @@ class SchedulerOutput(BaseInvocationOutput): class SchedulerInvocation(BaseInvocation): """Selects a scheduler.""" - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, @@ -116,7 +113,7 @@ def invoke(self, context: InvocationContext) -> SchedulerOutput: title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", - version="1.0.0", + version="1.0.1", ) class CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -131,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ui_order=4, ) - def prep_mask_tensor(self, mask_image): + def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: if mask_image.mode != "L": mask_image = mask_image.convert("L") - mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0) # if shape is not None: @@ -144,41 +141,34 @@ def prep_mask_tensor(self, mask_image): @torch.no_grad() def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) - image = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image.dim() == 3: - image = image.unsqueeze(0) + image = context.images.get_pil(self.image.image_name) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) else: - image = None + image_tensor = None mask = self.prep_mask_tensor( - context.services.images.get_pil_image(self.mask.image_name), + context.images.get_pil(self.mask.image_name), ) - if image is not None: - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + if image_tensor is not None: + vae_info = context.models.load(**self.vae.vae.model_dump()) - img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents" - context.services.latents.save(masked_latents_name, masked_latents) + masked_latents_name = context.tensors.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" - context.services.latents.save(mask_name, mask) + mask_name = context.tensors.save(tensor=mask) - return DenoiseMaskOutput( - denoise_mask=DenoiseMaskField( - mask_name=mask_name, - masked_latents_name=masked_latents_name, - ), + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=masked_latents_name, ) @@ 
-189,10 +179,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -200,7 +187,7 @@ def get_scheduler( scheduler_config = scheduler_config["_backup"] scheduler_config = { **scheduler_config, - **scheduler_extra_config, + **scheduler_extra_config, # FIXME "_backup": scheduler_config, } @@ -213,6 +200,7 @@ def get_scheduler( # hack copied over from generate.py if not hasattr(scheduler, "uses_inpainting_model"): scheduler.uses_inpainting_model = lambda: False + assert isinstance(scheduler, Scheduler) return scheduler @@ -221,7 +209,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", - version="1.5.1", + version="1.5.2", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -249,7 +237,7 @@ class DenoiseLatentsInvocation(BaseInvocation): description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, @@ -296,7 +284,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) @field_validator("cfg_scale") - def ge_one(cls, v): + def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: """validate that all cfg_scale values are >= 1""" if isinstance(v, list): for i in v: @@ -307,34 +295,18 @@ def ge_one(cls, v): raise ValueError("cfg_scale must be greater than 1") return v - # TODO: pass this an emitter method or something? or a session for dispatching? 
- def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - base_model=base_model, - ) - def get_conditioning_data( self, context: InvocationContext, - scheduler, - unet, - seed, + scheduler: Scheduler, + unet: UNet2DConditionModel, + seed: int, ) -> ConditioningData: - positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -351,7 +323,7 @@ def get_conditioning_data( ), ) - conditioning_data = conditioning_data.add_scheduler_args_if_applicable( + conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME scheduler, # for ddim scheduler eta=0.0, # ddim_eta @@ -363,8 +335,8 @@ def get_conditioning_data( def create_pipeline( self, - unet, - scheduler, + unet: UNet2DConditionModel, + scheduler: Scheduler, ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( @@ -375,10 +347,10 @@ def create_pipeline( class FakeVae: class FakeVaeConfig: - def __init__(self): + def __init__(self) -> None: self.block_out_channels = [0] - def __init__(self): + def __init__(self) -> None: self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( @@ -395,11 +367,11 @@ def __init__(self): def prep_control_data( self, context: InvocationContext, - control_input: Union[ControlField, List[ControlField]], + control_input: Optional[Union[ControlField, List[ControlField]]], latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, - ) -> List[ControlNetData]: + ) -> Optional[List[ControlNetData]]: # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR @@ -421,18 +393,11 @@ def prep_control_data( # and if weight is None, populate with default 1.0? controlnet_data = [] for control_info in control_list: - control_model = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=control_info.control_model.model_name, - model_type=ModelType.ControlNet, - base_model=control_info.control_model.base_model, - context=context, - ) - ) + control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key)) # control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? 
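The `prep_control_data()` hunk above relies on two behaviours of the new API: `context.models.load(key=...)` returns a loaded-model handle, and entering that handle (here via `ExitStack`) yields the underlying module and keeps it resident until the stack unwinds. The sketch below illustrates that pattern; the helper name and its body are illustrative assumptions, only the `context.models.load()` call and the `ExitStack` usage come from the hunks above.

```py
# Illustrative sketch only; this helper is not part of the diff.
from contextlib import ExitStack

from invokeai.app.services.shared.invocation_context import InvocationContext


def run_with_controlnets(context: InvocationContext, keys: list[str]) -> None:
    with ExitStack() as exit_stack:
        # Each load() returns a handle; enter_context() keeps every model available
        # until the ExitStack exits, mirroring the loop in prep_control_data().
        control_models = [exit_stack.enter_context(context.models.load(key=k)) for k in keys]

        for control_model in control_models:
            ...  # run the ControlNet forward passes here while the models are resident
```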
@@ -490,27 +455,17 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=single_ip_adapter.ip_adapter_model.model_name, - model_type=ModelType.IPAdapter, - base_model=single_ip_adapter.ip_adapter_model.base_model, - context=context, - ) + context.models.load(key=single_ip_adapter.ip_adapter_model.key) ) - image_encoder_model_info = context.services.model_manager.get_model( - model_name=single_ip_adapter.image_encoder_model.model_name, - model_type=ModelType.CLIPVision, - base_model=single_ip_adapter.image_encoder_model.base_model, - context=context, - ) + image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. - single_ipa_images = single_ip_adapter.image - if not isinstance(single_ipa_images, list): - single_ipa_images = [single_ipa_images] + single_ipa_image_fields = single_ip_adapter.image + if not isinstance(single_ipa_image_fields, list): + single_ipa_image_fields = [single_ipa_image_fields] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -554,23 +509,16 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( - model_name=t2i_adapter_field.t2i_adapter_model.model_name, - model_type=ModelType.T2IAdapter, - base_model=t2i_adapter_field.t2i_adapter_model.base_model, - context=context, - ) - image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) + t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) + image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. - if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: + if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1: max_unet_downscale = 8 - elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL: + elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL: max_unet_downscale = 4 else: - raise ValueError( - f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'." 
- ) + raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.") t2i_adapter_model: T2IAdapter with t2i_adapter_model_info as t2i_adapter_model: @@ -593,7 +541,7 @@ def run_t2i_adapters( do_classifier_free_guidance=False, width=t2i_input_width, height=t2i_input_height, - num_channels=t2i_adapter_model.config.in_channels, + num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict device=t2i_adapter_model.device, dtype=t2i_adapter_model.dtype, resize_mode=t2i_adapter_field.resize_mode, @@ -618,7 +566,15 @@ def run_t2i_adapters( # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps - def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): + def init_scheduler( + self, + scheduler: Union[Scheduler, ConfigMixin], + device: torch.device, + steps: int, + denoising_start: float, + denoising_end: float, + ) -> Tuple[int, List[int], int]: + assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): scheduler.set_timesteps(steps, device="cpu") timesteps = scheduler.timesteps.to(device=device) @@ -630,11 +586,11 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en _timesteps = timesteps[:: scheduler.order] # get start timestep index - t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) + t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) # get end timestep index - t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) + t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) # apply order to indexes @@ -647,14 +603,16 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask( + self, context: InvocationContext, latents: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if self.denoise_mask is None: return None, None - mask = context.services.latents.get(self.denoise_mask.mask_name) + mask = context.tensors.load(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -666,11 +624,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.services.latents.get(self.noise.latents_name) + noise = context.tensors.load(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -696,35 +654,30 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = 
context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] + # get the unet's config so that we can pass the base to dispatch_progress() + unet_config = context.models.get_config(self.unet.unet.key) - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, unet_config.base) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), - context=context, - ) - yield (lora_info.context.model, lora.weight) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) + yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) + assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, self.unet.seamless_axes), + ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), + set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): + assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) @@ -792,9 +745,8 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, result_latents) - return build_latents_output(latents_name=name, latents=result_latents, seed=seed) + name = context.tensors.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @invocation( @@ -802,9 +754,9 @@ def _lora_loader(): title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.2.0", + version="1.2.1", ) -class LatentsToImageInvocation(BaseInvocation, WithMetadata): +class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Generates an image from latents.""" latents: LatentsField = InputField( @@ -820,14 +772,12 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) - with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: + with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -854,7 +804,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae.to(dtype=torch.float16) latents = 
latents.half() - if self.tiled or context.services.configuration.tiled_decode: + if self.tiled or context.config.get().tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -878,22 +828,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] @@ -904,7 +841,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" @@ -927,7 +864,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -945,10 +882,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.tensors.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -956,7 +891,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Scale Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" @@ -970,7 +905,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -989,10 +924,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.tensors.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -1000,7 +933,7 @@ def invoke(self, context: 
InvocationContext) -> LatentsOutput: title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.0.0", + version="1.0.1", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -1016,8 +949,9 @@ class ImageToLatentsInvocation(BaseInvocation): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @staticmethod - def vae_encode(vae_info, upcast, tiled, image_tensor): + def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: with vae_info as vae: + assert isinstance(vae, torch.nn.Module) orig_dtype = vae.dtype if upcast: vae.to(dtype=torch.float32) @@ -1061,12 +995,9 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1074,22 +1005,26 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) - name = f"{context.graph_execution_state_id}__{self.id}" latents = latents.to("cpu") - context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=latents, seed=None) + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @staticmethod def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor: + assert isinstance(vae, torch.nn.Module) image_tensor_dist = vae.encode(image_tensor).latent_dist - latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible! + latents: torch.Tensor = image_tensor_dist.sample().to( + dtype=vae.dtype + ) # FIXME: uses torch.randn. make reproducible! return latents @_encode_to_tensor.register @staticmethod def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: - return vae.encode(image_tensor).latents + assert isinstance(vae, torch.nn.Module) + latents: torch.FloatTensor = vae.encode(image_tensor).latents + return latents @invocation( @@ -1097,7 +1032,7 @@ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTenso title="Blend Latents", tags=["latents", "blend"], category="latents", - version="1.0.0", + version="1.0.1", ) class BlendLatentsInvocation(BaseInvocation): """Blend two latents using a given alpha. 
Latents must have same size.""" @@ -1113,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.services.latents.get(self.latents_a.latents_name) - latents_b = context.services.latents.get(self.latents_b.latents_name) + latents_a = context.tensors.load(self.latents_a.latents_name) + latents_b = context.tensors.load(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1122,7 +1057,12 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: # TODO: device = choose_torch_device() - def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + def slerp( + t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? + v0: Union[torch.Tensor, npt.NDArray[Any]], + v1: Union[torch.Tensor, npt.NDArray[Any]], + DOT_THRESHOLD: float = 0.9995, + ) -> Union[torch.Tensor, npt.NDArray[Any]]: """ Spherical linear interpolation Args: @@ -1155,12 +1095,16 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): v2 = s0 * v0 + s1 * v1 if inputs_are_torch: - v2 = torch.from_numpy(v2).to(device) - - return v2 + v2_torch: torch.Tensor = torch.from_numpy(v2).to(device) + return v2_torch + else: + assert isinstance(v2, np.ndarray) + return v2 # blend - blended_latents = slerp(self.alpha, latents_a, latents_b) + bl = slerp(self.alpha, latents_a, latents_b) + assert isinstance(bl, torch.Tensor) + blended_latents: torch.Tensor = bl # for type checking convenience # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") @@ -1168,10 +1112,8 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, blended_latents) - return build_latents_output(latents_name=name, latents=blended_latents) + name = context.tensors.save(tensor=blended_latents) + return LatentsOutput.build(latents_name=name, latents=blended_latents) # The Crop Latents node was copied from @skunkworxdark's implementation here: @@ -1181,7 +1123,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): title="Crop Latents", tags=["latents", "crop"], category="latents", - version="1.0.0", + version="1.0.1", ) # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. # Currently, if the class names conflict then 'GET /openapi.json' fails. 
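
The hunks above replace the old `context.services.latents.get(...)` / `context.services.latents.save(...)` calls and the module-level `build_latents_output()` helper with `context.tensors.load(...)`, `context.tensors.save(tensor=...)` and the new `LatentsOutput.build()` classmethod. Below is a minimal sketch of a node written against that API, included only to illustrate the pattern: the node id, title, field names and exact import paths are assumptions inferred from the surrounding hunks, not part of this change.

```py
# Hypothetical example node (not part of this PR) showing the new
# InvocationContext tensor API used throughout the hunks above.
# Import paths are inferred from the diff and may differ slightly.
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation(
    "example_scale_latents",  # assumed id for illustration only
    title="Example: Scale Latents",
    tags=["latents"],
    category="latents",
    version="1.0.0",
)
class ExampleScaleLatentsInvocation(BaseInvocation):
    """Multiplies a latents tensor by a constant factor (illustrative only)."""

    latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
    factor: float = InputField(default=1.0, description="Scalar multiplier")

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        # Load the tensor by name (previously context.services.latents.get)
        latents = context.tensors.load(self.latents.latents_name)

        # Do the (trivial) work
        scaled = latents * self.factor

        # Save the result; the service returns the generated name
        # (previously a name was built from graph_execution_state_id and node id)
        name = context.tensors.save(tensor=scaled)

        # Build the output with the new classmethod instead of build_latents_output()
        return LatentsOutput.build(latents_name=name, latents=scaled, seed=self.latents.seed)
```

The same shift from free helper functions to `build()` classmethods appears for the other outputs touched by this diff (`NoiseOutput.build`, `ImageOutput.build`, `ConditioningOutput.build`), and the image path follows the analogous `context.images.get_pil(...)` / `context.images.save(image=...)` round trip shown in the `LatentsToImageInvocation` hunk above.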
@@ -1216,7 +1158,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1225,10 +1167,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, cropped_latents) + name = context.tensors.save(tensor=cropped_latents) - return build_latents_output(latents_name=name, latents=cropped_latents) + return LatentsOutput.build(latents_name=name, latents=cropped_latents) @invocation_output("ideal_size_output") @@ -1256,15 +1197,16 @@ class IdealSizeInvocation(BaseInvocation): description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)", ) - def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR): + def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]: return tuple((x - x % multiple_of) for x in args) def invoke(self, context: InvocationContext) -> IdealSizeOutput: + unet_config = context.models.get_config(**self.unet.unet.model_dump()) aspect = self.width / self.height - dimension = 512 - if self.unet.unet.base_model == BaseModelType.StableDiffusion2: + dimension: float = 512 + if unet_config.base == BaseModelType.StableDiffusion2: dimension = 768 - elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL: + elif unet_config.base == BaseModelType.StableDiffusionXL: dimension = 1024 dimension = dimension * self.multiplier min_dimension = math.floor(dimension * 0.5) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index defc61275fe..83a092be69e 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -5,10 +5,11 @@ import numpy as np from pydantic import ValidationInfo, field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 14d66f8ef68..58edfab711a 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,20 +5,22 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + invocation, + invocation_output, +) +from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, InputField, - InvocationContext, MetadataField, OutputField, UIType, - invocation, - invocation_output, ) -from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField -from 
invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from ...version import __version__ diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 99dcc72999b..6087bc82db1 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,31 +1,24 @@ import copy from typing import List, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from ...backend.model_management import BaseModelType, ModelType, SubModelType +from ...backend.model_manager import SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, invocation, invocation_output, ) class ModelInfo(BaseModel): - model_name: str = Field(description="Info to load submodel") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -55,7 +48,7 @@ class VaeField(BaseModel): @invocation_output("unet_output") class UNetOutput(BaseInvocationOutput): - """Base class for invocations that output a UNet field""" + """Base class for invocations that output a UNet field.""" unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") @@ -84,20 +77,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): class MainModelField(BaseModel): """Main model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model key") class LoRAModelField(BaseModel): """LoRA model field""" - model_name: str = Field(description="Name of the LoRA model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="LoRA model key") @invocation( @@ -105,7 +91,7 @@ class LoRAModelField(BaseModel): title="Main Model", tags=["model"], category="model", - version="1.0.0", + version="1.0.1", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -114,85 +100,40 @@ class MainModelLoaderInvocation(BaseInvocation): # TODO: precision? 
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ + if not context.models.exists(key): + raise Exception(f"Unknown model {key}") return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, + key=key, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, + key=key, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, + key=key, + submodel_type=SubModelType.Vae, ), ), ) @@ -206,7 +147,7 @@ class LoraLoaderOutput(BaseInvocationOutput): clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -229,21 +170,16 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unkown lora name: {lora_name}!") + if not context.models.exists(lora_key): + raise Exception(f"Unkown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in 
self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') output = LoraLoaderOutput() @@ -251,10 +187,8 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -263,10 +197,8 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -288,7 +220,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -318,24 +250,19 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unknown lora name: {lora_name}!") + if not context.models.exists(lora_key): + raise Exception(f"Unknown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') - if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip2') + if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip2') output = SDXLLoraLoaderOutput() @@ -343,10 +270,8 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -355,10 +280,8 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -367,10 +290,8 @@ def invoke(self, context: InvocationContext) -> 
SDXLLoraLoaderOutput: output.clip2 = copy.deepcopy(self.clip2) output.clip2.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - submodel=None, + key=lora_key, + submodel_type=None, weight=self.weight, ) ) @@ -381,13 +302,10 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: class VAEModelField(BaseModel): """Vae model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model's key") -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -398,25 +316,12 @@ class VaeLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> VAEOutput: - base_model = self.vae_model.base_model - model_name = self.vae_model.model_name - model_type = ModelType.Vae - - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, - ): - raise Exception(f"Unkown vae name: {model_name}!") - return VAEOutput( - vae=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - ) - ) + key = self.vae_model.key + + if not context.models.exists(key): + raise Exception(f"Unkown vae: {key}!") + + return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) @invocation_output("seamless_output") diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index b1ee91e1cdf..335d3df292e 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,17 +4,15 @@ import torch from pydantic import field_validator -from invokeai.app.invocations.latent import LatentsField -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, - InvocationContext, - OutputField, invocation, invocation_output, ) @@ -69,13 +67,13 @@ class NoiseOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) - -def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): - return NoiseOutput( - noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": + return cls( + noise=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, + ) @invocation( @@ -96,13 +94,13 @@ class NoiseInvocation(BaseInvocation): ) width: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.width, ) height: int = InputField( default=512, - multiple_of=8, + 
multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.height, ) @@ -124,6 +122,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, noise) - return build_noise_output(latents_name=name, latents=noise, seed=self.seed) + name = context.tensors.save(tensor=noise) + return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py deleted file mode 100644 index 759cfde700f..00000000000 --- a/invokeai/app/invocations/onnx.py +++ /dev/null @@ -1,508 +0,0 @@ -# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) - -import inspect - -# from contextlib import ExitStack -from typing import List, Literal, Union - -import numpy as np -import torch -from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator -from tqdm import tqdm - -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions -from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import BaseModelType, ModelType, SubModelType - -from ...backend.model_management import ONNXModelPatcher -from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util import choose_torch_device -from ..util.ti_utils import extract_ti_triggers_from_prompt -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, - UIComponent, - UIType, - WithMetadata, - invocation, - invocation_output, -) -from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler -from .model import ClipField, ModelInfo, UNetField, VaeField - -ORT_TO_NP_TYPE = { - "tensor(bool)": np.bool_, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - "tensor(int16)": np.int16, - "tensor(uint16)": np.uint16, - "tensor(int32)": np.int32, - "tensor(uint32)": np.uint32, - "tensor(int64)": np.int64, - "tensor(uint64)": np.uint64, - "tensor(float16)": np.float16, - "tensor(float)": np.float32, - "tensor(double)": np.float64, -} - -PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())] - - -@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") -class ONNXPromptInvocation(BaseInvocation): - prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) - clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - ) - with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.clip.loras - ] - - ti_list = [] - for 
trigger in extract_ti_triggers_from_prompt(self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except Exception: - # print(e) - # import traceback - # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') - if loras or ti_list: - text_encoder.release_session() - with ( - ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), - ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager), - ): - text_encoder.create_session() - - # copy from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153 - text_inputs = tokenizer( - self.prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - """ - untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - """ - - prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (prompt_embeds, None)) - - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) - - -# Text to image -@invocation( - "t2l_onnx", - title="ONNX Text to Latents", - tags=["latents", "inference", "txt2img", "onnx"], - category="latents", - version="1.0.0", -) -class ONNXTextToLatentsInvocation(BaseInvocation): - """Generates latents from conditionings.""" - - positive_conditioning: ConditioningField = InputField( - description=FieldDescriptions.positive_cond, - input=Input.Connection, - ) - negative_conditioning: ConditioningField = InputField( - description=FieldDescriptions.negative_cond, - input=Input.Connection, - ) - noise: LatentsField = InputField( - description=FieldDescriptions.noise, - input=Input.Connection, - ) - steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) - cfg_scale: Union[float, List[float]] = InputField( - default=7.5, - ge=1, - description=FieldDescriptions.cfg_scale, - ) - scheduler: SAMPLER_NAME_VALUES = InputField( - default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler - ) - precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) - unet: UNetField = InputField( - description=FieldDescriptions.unet, - input=Input.Connection, - ) - control: Union[ControlField, list[ControlField]] = InputField( - default=None, - description=FieldDescriptions.control, - ) - # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") - - 
@field_validator("cfg_scale") - def ge_one(cls, v): - """validate that all cfg_scale values are >= 1""" - if isinstance(v, list): - for i in v: - if i < 1: - raise ValueError("cfg_scale must be greater than 1") - else: - if v < 1: - raise ValueError("cfg_scale must be greater than 1") - return v - - # based on - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: - c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - if isinstance(c, torch.Tensor): - c = c.cpu().numpy() - if isinstance(uc, torch.Tensor): - uc = uc.cpu().numpy() - device = torch.device(choose_torch_device()) - prompt_embeds = np.concatenate([uc, c]) - - latents = context.services.latents.get(self.noise.latents_name) - if isinstance(latents, torch.Tensor): - latents = latents.cpu().numpy() - - # TODO: better execution device handling - latents = latents.astype(ORT_TO_NP_TYPE[self.precision]) - - # get the initial random noise unless the user supplied it - do_classifier_free_guidance = True - # latents_dtype = prompt_embeds.dtype - # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - # if latents.shape != latents_shape: - # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, - seed=0, # TODO: refactor this node - ) - - def torch2numpy(latent: torch.Tensor): - return latent.cpu().numpy() - - def numpy2torch(latent, device): - return torch.from_numpy(latent).to(device) - - def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - ) - - scheduler.set_timesteps(self.steps) - latents = latents * np.float64(scheduler.init_noise_sigma) - - extra_step_kwargs = {} - if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): - extra_step_kwargs.update( - eta=0.0, - ) - - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) - - with unet_info as unet: # , ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.unet.loras - ] - - if loras: - unet.release_session() - with ONNXModelPatcher.apply_lora_unet(unet, loras): - # TODO: - _, _, h, w = latents.shape - unet.create_session(h, w) - - timestep_dtype = next( - (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - for i in tqdm(range(len(scheduler.timesteps))): - t = scheduler.timesteps[i] - # expand the latents if we are doing classifier free guidance - latent_model_input = 
np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = scheduler.step( - numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs - ) - latents = torch2numpy(scheduler_output.prev_sample) - - state = PipelineIntermediateState( - run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample - ) - dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state) - - # call the callback, if provided - # if callback is not None and i % callback_steps == 0: - # callback(i, t, latents) - - torch.cuda.empty_cache() - - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) - - -# Latent to image -@invocation( - "l2i_onnx", - title="ONNX Latents to Image", - tags=["latents", "image", "vae", "onnx"], - category="image", - version="1.2.0", -) -class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): - """Generates an image from latents.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.denoised_latents, - input=Input.Connection, - ) - vae: VaeField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) - - if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") - - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - ) - - # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - - with vae_info as vae: - vae.create_session() - - # copied from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427 - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - image = VaeImageProcessor.numpy_to_pil(image)[0] - - torch.cuda.empty_cache() - - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - 
image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -@invocation_output("model_loader_output_onnx") -class ONNXModelLoaderOutput(BaseInvocationOutput): - """Model loader output""" - - unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") - clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") - vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") - vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") - - -class OnnxModelField(BaseModel): - """Onnx model field""" - - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) - - -@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") -class OnnxModelLoaderInvocation(BaseInvocation): - """Loads a main model, outputting its submodels.""" - - model: OnnxModelField = InputField( - description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel - ) - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! 
Check if model corrupted" - ) - """ - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeEncoder, - ), - ), - ) diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dccd18f754b..6845637de92 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -40,8 +40,10 @@ from matplotlib.ticker import MaxNLocator from invokeai.app.invocations.primitives import FloatCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField @invocation( @@ -109,7 +111,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: title="Step Param Easing", tags=["step", "easing"], category="step", - version="1.0.0", + version="1.0.1", ) class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" @@ -148,19 +150,19 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: postlist = list(num_poststeps * [self.post_end_value]) if log_diagnostics: - context.services.logger.debug("start_step: " + str(start_step)) - context.services.logger.debug("end_step: " + str(end_step)) - context.services.logger.debug("num_easing_steps: " + str(num_easing_steps)) - context.services.logger.debug("num_presteps: " + str(num_presteps)) - context.services.logger.debug("num_poststeps: " + str(num_poststeps)) - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("postlist size: " + str(len(postlist))) - context.services.logger.debug("prelist: " + str(prelist)) - context.services.logger.debug("postlist: " + str(postlist)) + context.logger.debug("start_step: " + str(start_step)) + context.logger.debug("end_step: " + str(end_step)) + context.logger.debug("num_easing_steps: " + str(num_easing_steps)) + context.logger.debug("num_presteps: " + str(num_presteps)) + context.logger.debug("num_poststeps: " + str(num_poststeps)) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist: " + str(prelist)) + context.logger.debug("postlist: " + str(postlist)) easing_class = EASING_FUNCTIONS_MAP[self.easing] if log_diagnostics: - context.services.logger.debug("easing class: " + str(easing_class)) + context.logger.debug("easing class: " + str(easing_class)) easing_list = [] if self.mirror: # "expected" mirroring # if number 
of steps is even, squeeze duration down to (number_of_steps)/2 @@ -171,7 +173,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) if log_diagnostics: - context.services.logger.debug("base easing duration: " + str(base_easing_duration)) + context.logger.debug("base easing duration: " + str(base_easing_duration)) even_num_steps = num_easing_steps % 2 == 0 # even number of steps easing_function = easing_class( start=self.start_value, @@ -183,14 +185,14 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: easing_val = easing_function.ease(step_index) base_easing_vals.append(easing_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) if even_num_steps: mirror_easing_vals = list(reversed(base_easing_vals)) else: mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) if log_diagnostics: - context.services.logger.debug("base easing vals: " + str(base_easing_vals)) - context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) + context.logger.debug("base easing vals: " + str(base_easing_vals)) + context.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) easing_list = base_easing_vals + mirror_easing_vals # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely @@ -225,12 +227,12 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: step_val = easing_function.ease(step_index) easing_list.append(step_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) if log_diagnostics: - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("easing_list size: " + str(len(easing_list))) - context.services.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("easing_list size: " + str(len(easing_list))) + context.logger.debug("postlist size: " + str(len(postlist))) param_list = prelist + easing_list + postlist diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9d..43422134829 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -1,20 +1,28 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional, Tuple +from typing import Optional import torch -from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions - -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.fields import ( + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, Input, InputField, - InvocationContext, + LatentsField, OutputField, UIComponent, +) +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.shared.invocation_context import InvocationContext + +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, invocation, invocation_output, ) @@ -221,18 +229,6 @@ def invoke(self, context: InvocationContext) -> 
StringCollectionOutput: # region Image -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - @invocation_output("image_output") class ImageOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" @@ -241,6 +237,14 @@ class ImageOutput(BaseInvocationOutput): width: int = OutputField(description="The width of the image in pixels") height: int = OutputField(description="The height of the image in pixels") + @classmethod + def build(cls, image_dto: ImageDTO) -> "ImageOutput": + return cls( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + @invocation_output("image_collection_output") class ImageCollectionOutput(BaseInvocationOutput): @@ -251,16 +255,14 @@ class ImageCollectionOutput(BaseInvocationOutput): ) -@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0") -class ImageInvocation( - BaseInvocation, -): +@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") +class ImageInvocation(BaseInvocation): """An image primitive value""" image: ImageField = InputField(description="The image to load") def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) return ImageOutput( image=ImageField(image_name=self.image.image_name), @@ -290,42 +292,40 @@ def invoke(self, context: InvocationContext) -> ImageCollectionOutput: # region DenoiseMask -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - @invocation_output("denoise_mask_output") class DenoiseMaskOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + @classmethod + def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": + return cls( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), + ) + # endregion # region Latents -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - @invocation_output("latents_output") class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" - latents: LatentsField = OutputField( - description=FieldDescriptions.latents, - ) + latents: LatentsField = OutputField(description=FieldDescriptions.latents) width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": + return cls( + latents=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, + ) + @invocation_output("latents_collection_output") 
class LatentsCollectionOutput(BaseInvocationOutput): @@ -337,7 +337,7 @@ class LatentsCollectionOutput(BaseInvocationOutput): @invocation( - "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0" + "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1" ) class LatentsInvocation(BaseInvocation): """A latents tensor primitive value""" @@ -345,9 +345,9 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) - return build_latents_output(self.latents.latents_name, latents) + return LatentsOutput.build(self.latents.latents_name, latents) @invocation( @@ -368,31 +368,11 @@ def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - - # endregion # region Color -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - @invocation_output("color_output") class ColorOutput(BaseInvocationOutput): """Base class for nodes that output a single color""" @@ -424,18 +404,16 @@ def invoke(self, context: InvocationContext) -> ColorOutput: # region Conditioning -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - - @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) + @classmethod + def build(cls, conditioning_name: str) -> "ConditioningOutput": + return cls(conditioning=ConditioningField(conditioning_name=conditioning_name)) + @invocation_output("conditioning_collection_output") class ConditioningCollectionOutput(BaseInvocationOutput): diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4778d980771..234743a0035 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -6,8 +6,10 @@ from pydantic import field_validator from invokeai.app.invocations.primitives import StringCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext -from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, UIComponent @invocation( diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 68076fdfeb1..0df27c00110 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,14 +1,10 @@ -from 
invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager import SubModelType -from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) @@ -34,7 +30,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -44,72 +40,52 @@ class SDXLModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.models.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=model_key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=model_key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, + key=model_key, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, + key=model_key, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer2, + key=model_key, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder2, + key=model_key, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, + key=model_key, + submodel_type=SubModelType.Vae, ), ), ) @@ -120,7 +96,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" @@ -133,56 +109,40 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): # TODO: 
precision? def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.models.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, + key=model_key, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, + key=model_key, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer2, + key=model_key, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder2, + key=model_key, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, + key=model_key, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index 3466206b377..182c976cd77 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -2,16 +2,15 @@ import re +from invokeai.app.services.shared.invocation_context import InvocationContext + from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, - InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) +from .fields import InputField, OutputField, UIComponent from .primitives import StringOutput diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e055d23903f..0f1e251bb36 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -1,29 +1,21 @@ from typing import Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES -from invokeai.app.invocations.primitives import ImageField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions -from invokeai.backend.model_management.models.base import BaseModelType +from invokeai.app.services.shared.invocation_context import InvocationContext class T2IAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the T2I-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = 
ConfigDict(protected_namespaces=()) + key: str = Field(description="Model record key for the T2I-Adapter model") class T2IAdapterField(BaseModel): diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index e51f891a8db..cb5373bbf75 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,16 +8,12 @@ BaseInvocation, BaseInvocationOutput, Classification, - Input, - InputField, - InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -236,7 +232,7 @@ def invoke(self, context: InvocationContext) -> PairTileImageOutput: version="1.1.0", classification=Classification.Beta, ) -class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): +class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Merge multiple tile images into a single image.""" # Inputs @@ -268,7 +264,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # existed in memory at an earlier point in the graph. tile_np_images: list[np.ndarray] = [] for image in images: - pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = context.images.get_pil(image.image_name) pil_image = pil_image.convert("RGB") tile_np_images.append(np.array(pil_image)) @@ -291,18 +287,5 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert into a PIL image and save pil_image = Image.fromarray(np_image) - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=pil_image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 5f715c1a7ed..2e2a6ce8813 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -8,13 +8,15 @@ from PIL import Image from pydantic import ConfigDict -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, WithBoard, WithMetadata # TODO: 
Populate this from disk? # TODO: Use model manager to load? @@ -29,8 +31,8 @@ from torch import mps -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0") -class ESRGANInvocation(BaseInvocation, WithMetadata): +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") +class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): """Upscales an image using RealESRGAN.""" image: ImageField = InputField(description="The input image") @@ -42,8 +44,8 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - models_path = context.services.configuration.models_path + image = context.images.get_pil(self.image.image_name) + models_path = context.config.get().models_path rrdbnet_model = None netscale = None @@ -87,7 +89,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: netscale = 2 else: msg = f"Invalid RealESRGAN model: {self.model_name}" - context.services.logger.error(msg) + context.logger.error(msg) raise ValueError(msg) esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") @@ -110,19 +112,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index a304b38a955..983df6b4684 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" initconf: ClassVar[Optional[DictConfig]] = None - argparse_groups: ClassVar[Dict] = {} + argparse_groups: ClassVar[Dict[str, Any]] = {} model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) - def parse_args(self, argv: Optional[list] = sys.argv[1:]): + def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None: """Call to parse command-line arguments.""" parser = self.get_parser() opt, unknown_opts = parser.parse_known_args(argv) @@ -68,7 +68,7 @@ def to_yaml(self) -> str: return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser): + def add_parser_arguments(cls, parser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] @@ -117,7 +117,8 @@ def cmd_name(cls, command_field: str = "type") -> str: """Return the category of a setting.""" hints = get_type_hints(cls) if command_field in hints: - return get_args(hints[command_field])[0] + result: str = get_args(hints[command_field])[0] + return result else: return "Uncategorized" @@ -158,7 
+159,7 @@ def _excluded_from_yaml(cls) -> List[str]: ] @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override=None): + def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None: """Add the argparse arguments for a setting parser.""" field_type = get_type_hints(cls).get(name) default = ( diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index d11bcabcf9c..27a0f859c23 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser): It also supports reading defaults from an init file. """ - def print_help(self, file=None): + def print_help(self, file=None) -> None: text = self.format_help() pydoc.pager(text) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 132afc22722..2af775372dd 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -173,7 +173,7 @@ class InvokeBatch(InvokeAISettings): import os from pathlib import Path -from typing import Any, ClassVar, Dict, List, Literal, Optional, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional from omegaconf import DictConfig, OmegaConf from pydantic import Field @@ -185,7 +185,9 @@ class InvokeBatch(InvokeAISettings): INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") -DEFAULT_MAX_VRAM = 0.5 +DEFAULT_RAM_CACHE = 10.0 +DEFAULT_VRAM_CACHE = 0.25 +DEFAULT_CONVERT_CACHE = 20.0 class Categories(object): @@ -237,6 +239,7 @@ class InvokeAIAppConfig(InvokeAISettings): autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) + convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths) db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths) outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths) @@ -260,8 +263,10 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other) # CACHE - ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) - vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", 
json_schema_extra=Categories.ModelCache, ) + vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) + lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache) @@ -404,6 +409,11 @@ def models_path(self) -> Path: """Path to the models directory.""" return self._resolve(self.models_dir) + @property + def models_convert_cache_path(self) -> Path: + """Path to the converted models cache directory.""" + return self._resolve(self.convert_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory.""" @@ -433,15 +443,20 @@ def invisible_watermark(self) -> bool: return True @property - def ram_cache_size(self) -> Union[Literal["auto"], float]: - """Return the ram cache size using the legacy or modern setting.""" + def ram_cache_size(self) -> float: + """Return the ram cache size using the legacy or modern setting (GB).""" return self.max_cache_size or self.ram @property - def vram_cache_size(self) -> Union[Literal["auto"], float]: - """Return the vram cache size using the legacy or modern setting.""" + def vram_cache_size(self) -> float: + """Return the vram cache size using the legacy or modern setting (GB).""" return self.max_vram_cache_size or self.vram + @property + def convert_cache_size(self) -> float: + """Return the convert cache size on disk (GB).""" + return self.convert_cache + @property def use_cpu(self) -> bool: """Return true if the device is set to CPU or the always_use_cpu flag is set.""" diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index f854f64f585..2ac13b825fe 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -260,3 +260,16 @@ def cancel_job(self, job: DownloadJob) -> None: def join(self) -> None: """Wait until all jobs are off the queue.""" pass + + @abstractmethod + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Wait until the indicated download job has reached a terminal state. + + This will block until the indicated download job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. 
+ """ + pass diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7613c0893fc..6d5cedbcad8 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -4,10 +4,11 @@ import os import re import threading +import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl @@ -48,11 +49,12 @@ def __init__( :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ - self._jobs = {} + self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 - self._queue = PriorityQueue() + self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() - self._worker_pool = set() + self._job_completed_event = threading.Event() + self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") self._event_bus = event_bus @@ -188,6 +190,16 @@ def cancel_all_jobs(self) -> None: if not job.in_terminal_state: self.cancel_job(job) + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._job_completed_event.wait(timeout=5): # in case we miss an event + self._job_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + def _start_workers(self, max_workers: int) -> None: """Start the requested number of worker threads.""" self._stop_event.clear() @@ -223,6 +235,7 @@ def _download_next_item(self) -> None: finally: job.job_ended = get_iso_timestamp() + self._job_completed_event.set() # signal a change to terminal state self._queue.task_done() self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") @@ -407,11 +420,11 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: # Example on_progress event handler to display a TQDM status bar # Activate with: -# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update +# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update)) class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" - _bars: Dict[int, tqdm] # the tqdm object + _bars: Dict[int, tqdm] # type: ignore _last: Dict[int, int] # last bytes downloaded def __init__(self) -> None: # noqa D107 diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f33495..90d9068b88c 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,8 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import ModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModelConfig class EventServiceBase: @@ -55,7 +54,7 @@ def emit_generator_progress( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - 
node: dict, + node_id: str, source_node_id: str, progress_image: Optional[ProgressImage], step: int, @@ -70,7 +69,7 @@ def emit_generator_progress( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "node_id": node.get("id"), + "node_id": node_id, "source_node_id": source_node_id, "progress_image": progress_image.model_dump() if progress_image is not None else None, "step": step, @@ -171,10 +170,7 @@ def emit_model_load_started( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is requested""" self.__emit_queue_event( @@ -184,10 +180,7 @@ def emit_model_load_started( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, + "model_config": model_config.model_dump(), }, ) @@ -197,11 +190,7 @@ def emit_model_load_completed( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -211,13 +200,7 @@ def emit_model_load_completed( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, - "hash": model_info.hash, - "location": str(model_info.location), - "precision": str(model_info.precision), + "model_config": model_config.model_dump(), }, ) diff --git a/invokeai/app/services/image_files/image_files_base.py b/invokeai/app/services/image_files/image_files_base.py index 27dd67531f4..f4036277b72 100644 --- a/invokeai/app/services/image_files/image_files_base.py +++ b/invokeai/app/services/image_files/image_files_base.py @@ -4,7 +4,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 08448216723..fb687973bad 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -7,7 +7,7 @@ from PIL.Image import Image as PILImageType from send2trash import send2trash -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 727f4977fba..7b7b261ecab 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -2,7 +2,7 @@ from datetime import 
datetime from typing import Optional -from invokeai.app.invocations.metadata import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.shared.pagination import OffsetPaginatedResults from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 74f82e7d84c..5b37913c8fd 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Optional, Union, cast -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index df71dadb5b0..42c42667744 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -3,7 +3,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecord, diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index ff21731a506..adeed738119 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -2,7 +2,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 4a503b3c6b1..c700f81186f 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -37,7 +37,8 @@ def start(self, invoker: Invoker) -> None: if self._max_cache_size == 0: return self._invoker.services.images.on_deleted(self._delete_by_match) - self._invoker.services.latents.on_deleted(self._delete_by_match) + self._invoker.services.tensors.on_deleted(self._delete_by_match) + self._invoker.services.conditioning.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: with self._lock: diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py index 54342c0da13..d2ebe235e63 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py @@ -5,11 +5,11 @@ from typing import Optional import invokeai.backend.util.logging as logger -from invokeai.app.invocations.baseinvocation import InvocationContext from 
invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem from invokeai.app.services.invocation_stats.invocation_stats_common import ( GESStatsNotFoundError, ) +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler from ..invoker import Invoker @@ -131,16 +131,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # which handles a few things: # - nodes that require a value, but get it only from a connection # - referencing the invocation cache instead of executing the node - outputs = invocation.invoke_internal( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - queue_batch_id=queue_item.session_queue_batch_id, - workflow=queue_item.workflow, - ) + context_data = InvocationContextData( + invocation=invocation, + session_id=graph_id, + workflow=queue_item.workflow, + source_node_id=source_node_id, + queue_id=queue_item.session_queue_id, + queue_item_id=queue_item.session_queue_item_id, + batch_id=queue_item.session_queue_batch_id, + ) + context = build_invocation_context( + services=self.__invoker.services, + context_data=context_data, ) + outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) # Check queue to see if this is canceled, and skip if so if self.__invoker.services.queue.is_canceled(graph_execution_state.id): diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 11a4de99d6e..0a1fa1e9222 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -3,9 +3,15 @@ from typing import TYPE_CHECKING +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + if TYPE_CHECKING: from logging import Logger + import torch + + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData + from .board_image_records.board_image_records_base import BoardImageRecordStorageBase from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase @@ -21,10 +27,7 @@ from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC - from .latents_storage.latents_storage_base import LatentsStorageBase - from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase - from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase @@ -36,33 +39,6 @@ class InvocationServices: """Services that can be used by invocations""" - # TODO: Just forward-declared everything due to circular dependencies. Fix structure. 
- board_images: "BoardImagesServiceABC" - board_image_record_storage: "BoardImageRecordStorageBase" - boards: "BoardServiceABC" - board_records: "BoardRecordStorageBase" - configuration: "InvokeAIAppConfig" - events: "EventServiceBase" - graph_execution_manager: "ItemStorageABC[GraphExecutionState]" - images: "ImageServiceABC" - image_records: "ImageRecordStorageBase" - image_files: "ImageFileStorageBase" - latents: "LatentsStorageBase" - logger: "Logger" - model_manager: "ModelManagerServiceBase" - model_records: "ModelRecordServiceBase" - download_queue: "DownloadQueueServiceBase" - model_install: "ModelInstallServiceBase" - processor: "InvocationProcessorABC" - performance_statistics: "InvocationStatsServiceBase" - queue: "InvocationQueueABC" - session_queue: "SessionQueueBase" - session_processor: "SessionProcessorBase" - invocation_cache: "InvocationCacheBase" - names: "NameServiceBase" - urls: "UrlServiceBase" - workflow_records: "WorkflowRecordsStorageBase" - def __init__( self, board_images: "BoardImagesServiceABC", @@ -75,12 +51,9 @@ def __init__( images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", - latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", - model_records: "ModelRecordServiceBase", download_queue: "DownloadQueueServiceBase", - model_install: "ModelInstallServiceBase", processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", @@ -90,6 +63,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", + tensors: "ObjectSerializerBase[torch.Tensor]", + conditioning: "ObjectSerializerBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records @@ -101,12 +76,9 @@ def __init__( self.images = images self.image_files = image_files self.image_records = image_records - self.latents = latents self.logger = logger self.model_manager = model_manager - self.model_records = model_records self.download_queue = download_queue - self.model_install = model_install self.processor = processor self.performance_statistics = performance_statistics self.queue = queue @@ -116,3 +88,5 @@ def __init__( self.names = names self.urls = urls self.workflow_records = workflow_records + self.tensors = tensors + self.conditioning = conditioning diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index be58aaad2dd..6c893021de4 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -2,6 +2,7 @@ import time from contextlib import contextmanager from pathlib import Path +from typing import Iterator import psutil import torch @@ -10,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_common import ( @@ -41,7 +42,12 @@ def start(self, invoker: Invoker) -> None: self._invoker = invoker @contextmanager - def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: 
str): + def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + # This is to handle case of the model manager not being initialized, which happens + # during some tests. + services = self._invoker.services + if services.model_manager is None or services.model_manager.load is None: + yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() @@ -55,8 +61,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st start_ram = psutil.Process().memory_info().rss if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - if self._invoker.services.model_manager: - self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id]) + + assert services.model_manager.load is not None + services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. @@ -73,7 +80,7 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def _prune_stale_stats(self): + def _prune_stale_stats(self) -> None: """Check all graphs being tracked and prune any that have completed/errored. This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index c93edf5188d..ef227ba241c 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -30,7 +30,7 @@ def get(self, item_id: str) -> T: @abstractmethod def set(self, item: T) -> None: """ - Sets the item. The id will be extracted based on id_field. + Sets the item. 
:param item: the item to set """ pass diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/latents_storage/latents_storage_base.py deleted file mode 100644 index 9fa42b0ae61..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Callable - -import torch - - -class LatentsStorageBase(ABC): - """Responsible for storing and retrieving latents.""" - - _on_changed_callbacks: list[Callable[[torch.Tensor], None]] - _on_deleted_callbacks: list[Callable[[str], None]] - - def __init__(self) -> None: - self._on_changed_callbacks = [] - self._on_deleted_callbacks = [] - - @abstractmethod - def get(self, name: str) -> torch.Tensor: - pass - - @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: - pass - - @abstractmethod - def delete(self, name: str) -> None: - pass - - def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None: - """Register a callback for when an item is changed""" - self._on_changed_callbacks.append(on_changed) - - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: - """Register a callback for when an item is deleted""" - self._on_deleted_callbacks.append(on_deleted) - - def _on_changed(self, item: torch.Tensor) -> None: - for callback in self._on_changed_callbacks: - callback(item) - - def _on_deleted(self, item_id: str) -> None: - for callback in self._on_deleted_callbacks: - callback(item_id) diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py deleted file mode 100644 index 9192b9147f7..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import Union - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class DiskLatentsStorage(LatentsStorageBase): - """Stores latents in a folder on disk without caching""" - - __output_folder: Path - - def __init__(self, output_folder: Union[str, Path]): - self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__output_folder.mkdir(parents=True, exist_ok=True) - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_latents() - - def get(self, name: str) -> torch.Tensor: - latent_path = self.get_path(name) - return torch.load(latent_path) - - def save(self, name: str, data: torch.Tensor) -> None: - self.__output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self.get_path(name) - torch.save(data, latent_path) - - def delete(self, name: str) -> None: - latent_path = self.get_path(name) - latent_path.unlink() - - def get_path(self, name: str) -> Path: - return self.__output_folder / name - - def _delete_all_latents(self) -> None: - """ - Deletes all latents from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). 
- """ - deleted_latents_count = 0 - freed_space = 0 - for latents_file in Path(self.__output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py deleted file mode 100644 index 6232b76a27d..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from queue import Queue -from typing import Dict, Optional - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class ForwardCacheLatentsStorage(LatentsStorageBase): - """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" - - __cache: Dict[str, torch.Tensor] - __cache_ids: Queue - __max_cache_size: int - __underlying_storage: LatentsStorageBase - - def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): - super().__init__() - self.__underlying_storage = underlying_storage - self.__cache = {} - self.__cache_ids = Queue() - self.__max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self.__underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self.__underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> torch.Tensor: - cache_item = self.__get_cache(name) - if cache_item is not None: - return cache_item - - latent = self.__underlying_storage.get(name) - self.__set_cache(name, latent) - return latent - - def save(self, name: str, data: torch.Tensor) -> None: - self.__underlying_storage.save(name, data) - self.__set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self.__underlying_storage.delete(name) - if name in self.__cache: - del self.__cache[name] - self._on_deleted(name) - - def __get_cache(self, name: str) -> Optional[torch.Tensor]: - return None if name not in self.__cache else self.__cache[name] - - def __set_cache(self, name: str, data: torch.Tensor): - if name not in self.__cache: - self.__cache[name] = data - self.__cache_ids.put(name) - if self.__cache_ids.qsize() > self.__max_cache_size: - self.__cache.pop(self.__cache_ids.get()) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 635cb154d64..39ea8c4a0d1 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -127,8 +127,8 @@ def proper_repo_id(cls, v: str) -> str: # noqa D102 def __str__(self) -> str: """Return string version of repoid when string rep needed.""" base: str = self.repo_id + base += f":{self.variant or ''}" base += f":{self.subfolder}" if self.subfolder else "" - base += f" ({self.variant})" if self.variant else "" return base @@ -324,6 +324,43 @@ def install_path( :returns id: The string ID of 
the registered model. """ + @abstractmethod + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + r"""Install the indicated model using heuristics to interpret user intentions. + + :param source: String source + :param config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record as described in `import_model()`. + :param access_token: Optional access token for remote sources. + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. + + The previous support for recursing into a local folder and loading all model-like files + has been removed. + """ + pass + @abstractmethod def import_model( self, @@ -385,6 +422,18 @@ def prune_jobs(self) -> None: def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" + @abstractmethod + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Wait for the indicated job to reach a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + @abstractmethod def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: """ @@ -394,7 +443,8 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: completed, been cancelled, or errored out. :param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if - installs do not complete within the indicated time. + installs do not complete within the indicated time. A timeout of zero (the default) + will block indefinitely until the installs complete. """ @abstractmethod @@ -410,3 +460,22 @@ def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: @abstractmethod def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" + + @abstractmethod + def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + """ + Download the model file located at source to the models cache and return its Path. + + :param source: A Url or a string that can be converted into one. + :param access_token: Optional access token to access restricted resources. + + The model file will be downloaded into the system-wide model cache + (`models/.cache`) if it isn't already there. Note that the model cache + is periodically cleared of infrequently-used entries when the model + converter runs. 
+ + Note that this doesn't automatically install or register the model, but is + intended for use by nodes that need access to models that aren't directly + supported by InvokeAI. The downloading process takes advantage of the download queue + to avoid interrupting other operations. + """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 82c667f584f..20a85a82a14 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,7 +17,7 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL @@ -50,6 +50,7 @@ ModelInstallJob, ModelInstallServiceBase, ModelSource, + StringLikeSource, URLModelSource, ) @@ -86,6 +87,7 @@ def __init__( self._lock = threading.Lock() self._stop_event = threading.Event() self._downloads_changed_event = threading.Event() + self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._running = False @@ -145,7 +147,7 @@ def register_path( ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get("source") is None: + if not config.get("source"): config["source"] = model_path.resolve().as_posix() return self._register(model_path, config) @@ -156,12 +158,14 @@ def install_path( ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get("source") is None: + if not config.get("source"): config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) - old_hash = info.original_hash - dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name + old_hash = info.current_hash + dest_path = ( + self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name) + ) try: new_path = self._copy_model(model_path, dest_path) except FileExistsError as excp: @@ -177,7 +181,40 @@ def install_path( info, ) + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=match.group(2) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + access_token=access_token, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=AnyHttpUrl(source), + access_token=access_token, + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return self.import_model(source_obj, config) + def import_model(self, source: 
ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 + similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] + if similar_jobs: + self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.") + return similar_jobs[0] + if isinstance(source, LocalModelSource): install_job = self._import_local_model(source, config) self._install_queue.put(install_job) # synchronously install @@ -207,6 +244,17 @@ def get_job_by_id(self, id: int) -> ModelInstallJob: # noqa D102 assert isinstance(jobs[0], ModelInstallJob) return jobs[0] + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._install_completed_event.wait(timeout=5): # in case we miss an event + self._install_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + + # TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" start = time.time() @@ -214,7 +262,7 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa if self._downloads_changed_event.wait(timeout=5): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: - raise Exception("Timeout exceeded") + raise TimeoutError("Timeout exceeded") self._install_queue.join() return self._install_jobs @@ -268,6 +316,38 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 path.unlink() self.unregister(key) + def download_and_cache( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: int = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path.""" + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + model_path = self._app_config.models_convert_cache_path / model_hash + + # We expect the cache directory to contain one and only one downloaded file. + # We don't know the file's name in advance, as it is set by the download + # content-disposition header. 
+ if model_path.exists(): + contents = [x for x in model_path.iterdir() if x.is_file()] + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + job = self._download_queue.download( + source=AnyHttpUrl(str(source)), + dest=model_path, + access_token=access_token, + on_progress=TqdmProgress().update, + ) + self._download_queue.wait_for_job(job, timeout) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -300,6 +380,7 @@ def _install_next_item(self) -> None: job.total_bytes = self._stat_size(job.local_path) job.bytes = job.total_bytes self._signal_job_running(job) + job.config_in["source"] = str(job.source) if job.inplace: key = self.register_path(job.local_path, job.config_in) else: @@ -330,6 +411,7 @@ def _install_next_item(self) -> None: # if this is an install of a remote file, then clean up the temporary directory if job._install_tmpdir is not None: rmtree(job._install_tmpdir) + self._install_completed_event.set() self._install_queue.task_done() self._logger.info("Install thread exiting") @@ -489,10 +571,10 @@ def _next_id(self) -> int: return id @staticmethod - def _guess_variant() -> ModelRepoVariant: + def _guess_variant() -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT + return ModelRepoVariant.FP16 if precision == "float16" else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( @@ -517,7 +599,7 @@ def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any] if not source.access_token: self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id) + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) assert isinstance(metadata, ModelMetadataWithFiles) remote_files = metadata.download_urls( variant=source.variant or self._guess_variant(), @@ -565,6 +647,8 @@ def _import_remote_model( # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # Currently the tmpdir isn't automatically removed at exit because it is # being held in a daemon thread. + if len(remote_files) == 0: + raise ValueError(f"{source}: No downloadable files found") tmpdir = Path( mkdtemp( dir=self._app_config.models_path, @@ -580,6 +664,16 @@ def _import_remote_model( bytes=0, total_bytes=0, ) + # In the event that there is a subfolder specified in the source, + # we need to remove it from the destination path in order to avoid + # creating unwanted subfolders + if hasattr(source, "subfolder") and source.subfolder: + root = Path(remote_files[0].path.parts[0]) + subfolder = root / source.subfolder + else: + root = Path(".") + subfolder = Path(".") + # we remember the path up to the top of the tmpdir so that it may be # removed safely at the end of the install process. 
install_job._install_tmpdir = tmpdir @@ -589,7 +683,7 @@ def _import_remote_model( self._logger.debug(f"remote_files={remote_files}") for model_file in remote_files: url = model_file.url - path = model_file.path + path = root / model_file.path.relative_to(subfolder) self._logger.info(f"Downloading {url} => {path}") install_job.total_bytes += model_file.size assert hasattr(source, "access_token") diff --git a/invokeai/app/services/model_load/__init__.py b/invokeai/app/services/model_load/__init__.py new file mode 100644 index 00000000000..b4a86e9348d --- /dev/null +++ b/invokeai/app/services/model_load/__init__.py @@ -0,0 +1,6 @@ +"""Initialization file for model load service module.""" + +from .model_load_base import ModelLoadServiceBase +from .model_load_default import ModelLoadService + +__all__ = ["ModelLoadServiceBase", "ModelLoadService"] diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py new file mode 100644 index 00000000000..f4dd905135a --- /dev/null +++ b/invokeai/app/services/model_load/model_load_base.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team +"""Base class for model loader.""" + +from abc import ABC, abstractmethod +from typing import Optional + +from invokeai.app.services.shared.invocation_context import InvocationContextData +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load import LoadedModel +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase + + +class ModelLoadServiceBase(ABC): + """Wrapper around AnyModelLoader.""" + + @abstractmethod + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + """ + Given a model's key, load it and return the LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel: For main (pipeline models), the submodel to fetch. + :param context_data: Invocation context data used for event reporting + """ + pass + + @abstractmethod + def load_model_by_config( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + """ + Given a model's configuration, load it and return the LoadedModel object. + + :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) + :param submodel: For main (pipeline models), the submodel to fetch. + :param context_data: Invocation context data used for event reporting + """ + pass + + @abstractmethod + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + """ + Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. + + This is provided for API compatibility with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that was returned. + + :param model_name: Name of model to be fetched. 
+ :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context_data: The invocation context data. + + Exceptions: UnknownModelException -- model with these attributes not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination + """ + + @property + @abstractmethod + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + + @property + @abstractmethod + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py new file mode 100644 index 00000000000..29b297c8145 --- /dev/null +++ b/invokeai/app/services/model_load/model_load_default.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team +"""Implementation of model loader service.""" + +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException +from invokeai.app.services.shared.invocation_context import InvocationContextData +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.util.logging import InvokeAILogger + +from .model_load_base import ModelLoadServiceBase + + +class ModelLoadService(ModelLoadServiceBase): + """Wrapper around AnyModelLoader.""" + + def __init__( + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + ram_cache: Optional[ModelCacheBase[AnyModel]] = None, + convert_cache: Optional[ModelConvertCacheBase] = None, + ): + """Initialize the model load service.""" + logger = InvokeAILogger.get_logger(self.__class__.__name__) + logger.setLevel(app_config.log_level.upper()) + self._store = record_store + self._any_loader = AnyModelLoader( + app_config=app_config, + logger=logger, + ram_cache=ram_cache + or ModelCache( + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size, + logger=logger, + ), + convert_cache=convert_cache + or ModelConvertCache( + cache_path=app_config.models_convert_cache_path, + max_size=app_config.convert_cache_size, + ), + ) + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + + @property + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + return self._any_loader.ram_cache + + @property + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" + return self._any_loader.convert_cache + + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + """ + Given a model's key, load it and return the LoadedModel object. 
+
+        :param key: Key of model config to be fetched.
+        :param submodel: For main (pipeline models), the submodel to fetch.
+        :param context: Invocation context used for event reporting
+        """
+        config = self._store.get_model(key)
+        return self.load_model_by_config(config, submodel_type, context_data)
+
+    def load_model_by_attr(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        context_data: Optional[InvocationContextData] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
+
+        This is provided for API compatibility with the get_model() method
+        in the original model manager. However, note that LoadedModel is
+        not the same as the original ModelInfo that was returned.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        :param context: The invocation context.
+
+        Exceptions: UnknownModelException -- model with these attributes not known
+                    NotImplementedException -- a model loader was not provided at initialization time
+                    ValueError -- more than one model matches this combination
+        """
+        configs = self._store.search_by_attr(model_name, base_model, model_type)
+        if len(configs) == 0:
+            raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
+        elif len(configs) > 1:
+            raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
+        else:
+            return self.load_model_by_key(configs[0].key, submodel)
+
+    def load_model_by_config(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        context_data: Optional[InvocationContextData] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's configuration, load it and return the LoadedModel object.
+
+        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
+        :param submodel: For main (pipeline models), the submodel to fetch.
+ :param context: Invocation context used for event reporting + """ + if context_data: + self._emit_load_event( + context_data=context_data, + model_config=model_config, + ) + loaded_model = self._any_loader.load_model(model_config, submodel_type) + if context_data: + self._emit_load_event( + context_data=context_data, + model_config=model_config, + loaded=True, + ) + return loaded_model + + def _emit_load_event( + self, + context_data: InvocationContextData, + model_config: AnyModelConfig, + loaded: Optional[bool] = False, + ) -> None: + if not self._invoker: + return + if self._invoker.services.queue.is_canceled(context_data.session_id): + raise CanceledException() + + if not loaded: + self._invoker.services.events.emit_model_load_started( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, + model_config=model_config, + ) + else: + self._invoker.services.events.emit_model_load_completed( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, + model_config=model_config, + ) diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 3d6a9c248c6..5e281922a8b 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -1 +1,16 @@ -from .model_manager_default import ModelManagerService # noqa F401 +"""Initialization file for model manager service.""" + +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load import LoadedModel + +from .model_manager_default import ModelManagerService + +__all__ = [ + "ModelManagerService", + "AnyModel", + "AnyModelConfig", + "BaseModelType", + "ModelType", + "SubModelType", + "LoadedModel", +] diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 4c2fc4c085c..1116c82ff1f 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,286 +1,67 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team -from __future__ import annotations - from abc import ABC, abstractmethod -from logging import Logger -from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union -from pydantic import Field +from typing_extensions import Self -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - MergeInterpolationMethod, - ModelInfo, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.app.services.invoker import Invoker -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallServiceBase +from ..model_load import ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase +from ..shared.sqlite.sqlite_database import SqliteDatabase class ModelManagerServiceBase(ABC): - """Responsible for managing models on disk and in memory""" - - @abstractmethod - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - pass - - @abstractmethod - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, - ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) - of a diffusers pipeline.""" - pass - - @property - @abstractmethod - def logger(self): - pass - - @abstractmethod - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - pass - - @abstractmethod - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - Uses the exact format as the omegaconf stanza. - """ - pass - - @abstractmethod - def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict: - """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } - """ - pass - - @abstractmethod - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Return information about the model using the same format as list_models() - """ - pass - - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. 
- """ - pass - - @abstractmethod - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException if the name does not already exist. - - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. - """ - pass - - @abstractmethod - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): - """ - Rename the indicated model. - """ - pass + """Abstract base class for the model manager service.""" - @abstractmethod - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - pass + # attributes: + # store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") + # install: ModelInstallServiceBase = Field(description="An instance of the model install service.") + # load: ModelLoadServiceBase = Field(description="An instance of the model load service.") + @classmethod @abstractmethod - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ - pass - - @abstractmethod - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. 
- :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. + Construct the model manager service instance. - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. + Use it rather than the __init__ constructor. This class + method simplifies the construction considerably. """ pass + @property @abstractmethod - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ + def store(self) -> ModelRecordServiceBase: + """Return the ModelRecordServiceBase used to store and retrieve configuration records.""" pass + @property @abstractmethod - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ + def load(self) -> ModelLoadServiceBase: + """Return the ModelLoadServiceBase used to load models from their configuration records.""" pass + @property @abstractmethod - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ + def install(self) -> ModelInstallServiceBase: + """Return the ModelInstallServiceBase used to download and manipulate model files.""" pass @abstractmethod - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ + def start(self, invoker: Invoker) -> None: pass @abstractmethod - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. 
- """ + def stop(self, invoker: Invoker) -> None: pass diff --git a/invokeai/app/services/latents_storage/__init__.py b/invokeai/app/services/model_manager/model_manager_common.py similarity index 100% rename from invokeai/app/services/latents_storage/__init__.py rename to invokeai/app/services/model_manager/model_manager_common.py diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cdb3e59a91c..028d4af6159 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,413 +1,100 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team - -from __future__ import annotations - -from logging import Logger -from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union - -import torch -from pydantic import Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - MergeInterpolationMethod, - ModelInfo, - ModelManager, - ModelMerger, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats -from invokeai.backend.model_management.model_search import FindModels -from invokeai.backend.util import choose_precision, choose_torch_device - +"""Implementation of ModelManagerServiceBase.""" + +from typing_extensions import Self + +from invokeai.app.services.invoker import Invoker +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache +from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.util.logging import InvokeAILogger + +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallService, ModelInstallServiceBase +from ..model_load import ModelLoadService, ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL +from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_manager_base import ModelManagerServiceBase -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import InvocationContext - -# simple implementation class ModelManagerService(ModelManagerServiceBase): - """Responsible for managing models on disk and in memory""" - - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. 
- """ - if config.model_conf_path and config.model_conf_path.exists(): - config_file = config.model_conf_path - else: - config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f"Config file={config_file}") - - device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" - logger.info(f"GPU device = {device} {device_name}") - - precision = config.precision - if precision == "auto": - precision = choose_precision(device) - dtype = torch.float32 if precision == "float32" else torch.float16 - - # this is transitional backward compatibility - # support for the deprecated `max_loaded_models` - # configuration value. If present, then the - # cache size is set to 2.5 GB times - # the number of max_loaded_models. Otherwise - # use new `ram_cache_size` config setting - max_cache_size = config.ram_cache_size - - logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") - - sequential_offload = config.sequential_guidance - - self.mgr = ModelManager( - config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger, - ) - logger.info("Model manager service initialized") - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> ModelInfo: - """ - Retrieve the indicated model. submodel can be used to get a - part (such as the vae) of a diffusers mode. - """ - - # we can emit model loading events if we are executing with access to the invocation context - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - model_info = self.mgr.get_model( - model_name, - base_model, - model_type, - submodel, - ) - - if context: - self._emit_load_event( - context=context, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - model_info=model_info, - ) - - return model_info - - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - """ - Given a model name, returns True if it is a valid - identifier. - """ - return self.mgr.model_exists( - model_name, - base_model, - model_type, - ) - - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - """ - return self.mgr.model_info(model_name, base_model, model_type) - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - return self.mgr.model_names() - - def list_models( - self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> list[dict]: - """ - Return a list of models. 
- """ - return self.mgr.list_models(base_model, model_type) - - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Return information about the model using the same format as list_models() - """ - return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) - - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"add/update model {model_name}") - return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) + """ + The ModelManagerService handles various aspects of model installation, maintenance and loading. - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException exception if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"update model {model_name}") - if not self.model_exists(model_name, base_model, model_type): - raise ModelNotFoundException(f"Unknown model {model_name}") - return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) + It bundles three distinct services: + model_manager.store -- Routines to manage the database of model configuration records. + model_manager.install -- Routines to install, move and delete models. + model_manager.load -- Routines to load models into memory. + """ - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. - """ - self.logger.debug(f"delete model {model_name}") - self.mgr.del_model(model_name, base_model, model_type) - self.mgr.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - convert_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. 
- """ - self.logger.debug(f"convert model {model_name}") - return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) - - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ - self.mgr.cache.stats = cache_stats - - def commit(self, conf_file: Optional[Path] = None): - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ - return self.mgr.commit(conf_file) - - def _emit_load_event( + def __init__( self, - context: InvocationContext, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - model_info: Optional[ModelInfo] = None, + store: ModelRecordServiceBase, + install: ModelInstallServiceBase, + load: ModelLoadServiceBase, ): - if context.services.queue.is_canceled(context.graph_execution_state_id): - raise CanceledException() - - if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - model_info=model_info, - ) - else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) + self._store = store + self._install = install + self._load = load @property - def logger(self): - return self.mgr.logger + def store(self) -> ModelRecordServiceBase: + return self._store - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. + @property + def install(self) -> ModelInstallServiceBase: + return self._install - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. 
- """ - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + @property + def load(self) -> ModelLoadServiceBase: + return self._load - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - merger = ModelMerger(self.mgr) - try: - result = merger.merge_diffusion_models_and_save( - model_names=model_names, - base_model=base_model, - merged_model_name=merged_model_name, - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=merge_dest_directory, - ) - except AssertionError as e: - raise ValueError(e) - return result + def start(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "start"): + service.start(invoker) - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - search = FindModels([directory], self.logger) - return search.list_models() + def stop(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "stop"): + service.stop(invoker) - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. + @classmethod + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ - return self.mgr.sync_to_config() + Construct the model manager service instance. - def list_checkpoint_configs(self) -> List[Path]: + For simplicity, use this class method rather than the __init__ constructor. """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - config = self.mgr.app_config - conf_path = config.legacy_conf_path - root_path = config.root_path - return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] + logger = InvokeAILogger.get_logger(cls.__name__) + logger.setLevel(app_config.log_level.upper()) - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ): - """ - Rename the indicated model. Can provide a new name and/or a new base. 
- :param model_name: Current name of the model - :param base_model: Current base of the model - :param model_type: Model type (can't be changed) - :param new_name: New name for the model - :param new_base: New base for the model - """ - self.mgr.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=new_name, - new_base=new_base, + ram_cache = ModelCache( + max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger + ) + convert_cache = ModelConvertCache( + cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size + ) + record_store = ModelRecordServiceSQL(db=db) + loader = ModelLoadService( + app_config=app_config, + record_store=record_store, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + record_store._loader = loader # yeah, there is a circular reference here + installer = ModelInstallService( + app_config=app_config, + record_store=record_store, + download_queue=download_queue, + metadata_store=ModelMetadataStore(db=db), + event_bus=events, ) + return cls(store=record_store, install=installer, load=loader) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 57597570cde..b2eacc524b7 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,7 +11,12 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -146,7 +151,7 @@ def list_models( @abstractmethod def exists(self, key: str) -> bool: """ - Return True if a model with the indicated key exists in the databse. + Return True if a model with the indicated key exists in the database. :param key: Unique key for the model to be deleted """ diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 4512da5d413..84a14123838 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -73,12 +73,11 @@ def __init__(self, db: SqliteDatabase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. 
- :param conn: sqlite3 connection object - :param lock: threading Lock object + :param db: Sqlite connection object """ super().__init__() self._db = db - self._cursor = self._db.conn.cursor() + self._cursor = db.conn.cursor() @property def db(self) -> SqliteDatabase: @@ -199,7 +198,7 @@ def get_model(self, key: str) -> AnyModelConfig: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE id=?; """, (key,), @@ -207,7 +206,7 @@ def get_model(self, key: str) -> AnyModelConfig: rows = self._cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0])) + model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model def exists(self, key: str) -> bool: @@ -265,12 +264,14 @@ def search_by_attr( with self._db.lock: self._cursor.execute( f"""--sql - select config FROM model_config + select config, strftime('%s',updated_at) FROM model_config {where}; """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -279,12 +280,14 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -293,12 +296,14 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE original_hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results @property diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py new file mode 100644 index 00000000000..ff19b4a039d --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from typing import Callable, Generic, TypeVar + +T = TypeVar("T") + + +class ObjectSerializerBase(ABC, Generic[T]): + """Saves and loads arbitrary python objects.""" + + def __init__(self) -> None: + self._on_deleted_callbacks: list[Callable[[str], None]] = [] + + @abstractmethod + def load(self, name: str) -> T: + """ + Loads the object. + :param name: The name of the object to load. + :raises ObjectNotFoundError: if the object is not found + """ + pass + + @abstractmethod + def save(self, obj: T) -> str: + """ + Saves the object, returning its name. + :param obj: The object to save. + """ + pass + + @abstractmethod + def delete(self, name: str) -> None: + """ + Deletes the object, if it exists. + :param name: The name of the object to delete. 
+ """ + pass + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an object is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_deleted(self, name: str) -> None: + for callback in self._on_deleted_callbacks: + callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_common.py b/invokeai/app/services/object_serializer/object_serializer_common.py new file mode 100644 index 00000000000..7057386541f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_common.py @@ -0,0 +1,5 @@ +class ObjectNotFoundError(KeyError): + """Raised when an object is not found while loading""" + + def __init__(self, name: str) -> None: + super().__init__(f"Object with name {name} not found") diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py new file mode 100644 index 00000000000..935fec30605 --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -0,0 +1,85 @@ +import tempfile +import typing +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Optional, TypeVar + +import torch + +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.util.misc import uuid_string + +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + + +T = TypeVar("T") + + +@dataclass +class DeleteAllResult: + deleted_count: int + freed_space_bytes: float + + +class ObjectSerializerDisk(ObjectSerializerBase[T]): + """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. 
+ + :param output_dir: The folder where the serialized objects will be stored + :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit + """ + + def __init__(self, output_dir: Path, ephemeral: bool = False): + super().__init__() + self._ephemeral = ephemeral + self._base_output_dir = output_dir + self._base_output_dir.mkdir(parents=True, exist_ok=True) + # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows + self._tempdir = ( + tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None + ) + self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir + self.__obj_class_name: Optional[str] = None + + def load(self, name: str) -> T: + file_path = self._get_path(name) + try: + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + except FileNotFoundError as e: + raise ObjectNotFoundError(name) from e + + def save(self, obj: T) -> str: + name = self._new_name() + file_path = self._get_path(name) + torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType] + return name + + def delete(self, name: str) -> None: + file_path = self._get_path(name) + file_path.unlink() + + @property + def _obj_class_name(self) -> str: + if not self.__obj_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue] + return self.__obj_class_name + + def _get_path(self, name: str) -> Path: + return self._output_dir / name + + def _new_name(self) -> str: + return f"{self._obj_class_name}_{uuid_string()}" + + def _tempdir_cleanup(self) -> None: + """Calls `cleanup` on the temporary directory, if it exists.""" + if self._tempdir: + self._tempdir.cleanup() + + def __del__(self) -> None: + # In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd. + self._tempdir_cleanup() + + def stop(self, invoker: "Invoker") -> None: + self._tempdir_cleanup() diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py new file mode 100644 index 00000000000..b361259a4b1 --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -0,0 +1,65 @@ +from queue import Queue +from typing import TYPE_CHECKING, Optional, TypeVar + +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + +T = TypeVar("T") + +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + + +class ObjectSerializerForwardCache(ObjectSerializerBase[T]): + """ + Provides a LRU cache for an instance of `ObjectSerializerBase`. + Saving an object to the cache always writes through to the underlying storage. 
+ """ + + def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: "Invoker") -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: "Invoker") -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def load(self, name: str) -> T: + cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + obj = self._underlying_storage.load(name) + self._set_cache(name, obj) + return obj + + def save(self, obj: T) -> str: + name = self._underlying_storage.save(obj) + self._set_cache(name, obj) + return name + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1acf165abac..3df230f5ee7 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,14 +13,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, - InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py new file mode 100644 index 00000000000..089d09f825c --- /dev/null +++ b/invokeai/app/services/shared/invocation_context.py @@ -0,0 +1,461 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Optional + +from PIL.Image import Image +from torch import Tensor + +from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata +from invokeai.app.services.boards.boards_common import BoardDTO +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata +from invokeai.backend.stable_diffusion.diffusers_pipeline 
import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation + +""" +The InvocationContext provides access to various services and data about the current invocation. + +We do not provide the invocation services directly, as their methods are both dangerous and +inconvenient to use. + +For example: +- The `images` service allows nodes to delete or unsafely modify existing images. +- The `configuration` service allows nodes to change the app's config at runtime. +- The `events` service allows nodes to emit arbitrary events. + +Wrapping these services provides a simpler and safer interface for nodes to use. + +When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere +with each other. + +Many of the wrappers have the same signature as the methods they wrap. This allows us to write +user-facing docstrings and not need to go and update the internal services to match. + +Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. +""" + + +@dataclass +class InvocationContextData: + invocation: "BaseInvocation" + """The invocation that is being executed.""" + session_id: str + """The session that is being executed.""" + queue_id: str + """The queue in which the session is being executed.""" + source_node_id: str + """The ID of the node from which the currently executing invocation was prepared.""" + queue_item_id: int + """The ID of the queue item that is being executed.""" + batch_id: str + """The ID of the batch that is being executed.""" + workflow: Optional[WorkflowWithoutID] = None + """The workflow associated with this queue item, if any.""" + + +class InvocationContextInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + self._services = services + self._context_data = context_data + + +class BoardsInterface(InvocationContextInterface): + def create(self, board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return self._services.boards.create(board_name) + + def get_dto(self, board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return self._services.boards.get_dto(board_id) + + def get_all(self) -> list[BoardDTO]: + """ + Gets all boards. + """ + return self._services.boards.get_all() + + def add_image_to_board(self, board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + return self._services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(self, board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return self._services.board_images.get_all_board_image_names_for_board(board_id) + + +class LoggerInterface(InvocationContextInterface): + def debug(self, message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + self._services.logger.debug(message) + + def info(self, message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. 
+ """ + self._services.logger.info(message) + + def warning(self, message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + self._services.logger.warning(message) + + def error(self, message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. + """ + self._services.logger.error(message) + + +class ImagesInterface(InvocationContextInterface): + def save( + self, + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow or metadata, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. It the invocation \ + inherits from `WithBoard`, that board will be used automatically. **Use this only if \ + you want to override or provide a board manually!** + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** + """ + + # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. + metadata_ = None + if metadata: + metadata_ = metadata + elif isinstance(self._context_data.invocation, WithMetadata): + metadata_ = self._context_data.invocation.metadata + + # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. + board_id_ = None + if board_id: + board_id_ = board_id + elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board: + board_id_ = self._context_data.invocation.board.board_id + + return self._services.images.create( + image=image, + is_intermediate=self._context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id_, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=self._context_data.workflow, + session_id=self._context_data.session_id, + node_id=self._context_data.invocation.id, + ) + + def get_pil(self, image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_pil_image(image_name) + + def get_metadata(self, image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return self._services.images.get_metadata(image_name) + + def get_dto(self, image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_dto(image_name) + + +class TensorsInterface(InvocationContextInterface): + def save(self, tensor: Tensor) -> str: + """ + Saves a tensor, returning its name. + + :param tensor: The tensor to save. + """ + + name = self._services.tensors.save(obj=tensor) + return name + + def load(self, name: str) -> Tensor: + """ + Loads a tensor by name. + + :param name: The name of the tensor to load. 
+ """ + return self._services.tensors.load(name) + + +class ConditioningInterface(InvocationContextInterface): + def save(self, conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. + + :param conditioning_context_data: The conditioning data to save. + """ + + name = self._services.conditioning.save(obj=conditioning_data) + return name + + def load(self, name: str) -> ConditioningFieldData: + """ + Loads conditioning data by name. + + :param name: The name of the conditioning data to load. + """ + + return self._services.conditioning.load(name) + + +class ModelsInterface(InvocationContextInterface): + def exists(self, key: str) -> bool: + """ + Checks if a model exists. + + :param key: The key of the model. + """ + return self._services.model_manager.store.exists(key) + + def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Loads a model. + + :param key: The key of the model. + :param submodel_type: The submodel of the model to get. + :returns: An object representing the loaded model. + """ + + # The model manager emits events as it loads the model. It needs the context data to build + # the event payloads. + + return self._services.model_manager.load.load_model_by_key( + key=key, submodel_type=submodel_type, context_data=self._context_data + ) + + def load_by_attrs( + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> LoadedModel: + """ + Loads a model by its attributes. + + :param model_name: Name of to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + """ + return self._services.model_manager.load.load_model_by_attr( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=submodel, + context_data=self._context_data, + ) + + def get_config(self, key: str) -> AnyModelConfig: + """ + Gets a model's info, an dict-like object. + + :param key: The key of the model. + """ + return self._services.model_manager.store.get_model(key=key) + + def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: + """ + Gets a model's metadata, if it has any. + + :param key: The key of the model. + """ + return self._services.model_manager.store.get_metadata(key=key) + + def search_by_path(self, path: Path) -> list[AnyModelConfig]: + """ + Searches for models by path. + + :param path: The path to search for. + """ + return self._services.model_manager.store.search_by_path(path) + + def search_by_attrs( + self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + model_format: Optional[ModelFormat] = None, + ) -> list[AnyModelConfig]: + """ + Searches for models by attributes. + + :param model_name: Name of to be fetched. 
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param model_format: Format of the model
+        """
+
+        return self._services.model_manager.store.search_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_format=model_format,
+        )
+
+
+class ConfigInterface(InvocationContextInterface):
+    def get(self) -> InvokeAIAppConfig:
+        """Gets the app's config."""
+
+        return self._services.configuration.get_config()
+
+
+class UtilInterface(InvocationContextInterface):
+    def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
+        """
+        The step callback emits a progress event with the current step, the total number of
+        steps, a preview image, and some other internal metadata.
+
+        This should be called after each denoising step.
+
+        :param intermediate_state: The intermediate state of the diffusion pipeline.
+        :param base_model: The base model for the current denoising step.
+        """
+
+        # The step callback needs access to the events and the invocation queue services, but this
+        # represents a dangerous level of access.
+        #
+        # We wrap the step callback so that nodes do not have direct access to these services.
+
+        stable_diffusion_step_callback(
+            context_data=self._context_data,
+            intermediate_state=intermediate_state,
+            base_model=base_model,
+            invocation_queue=self._services.queue,
+            events=self._services.events,
+        )
+
+
+class InvocationContext:
+    """
+    The `InvocationContext` provides access to various services and data for the current invocation.
+    """
+
+    def __init__(
+        self,
+        images: ImagesInterface,
+        tensors: TensorsInterface,
+        conditioning: ConditioningInterface,
+        models: ModelsInterface,
+        logger: LoggerInterface,
+        config: ConfigInterface,
+        util: UtilInterface,
+        boards: BoardsInterface,
+        context_data: InvocationContextData,
+        services: InvocationServices,
+    ) -> None:
+        self.images = images
+        """Provides methods to save, get and update images and their metadata."""
+        self.tensors = tensors
+        """Provides methods to save and get tensors, including image, noise, masks, and masked images."""
+        self.conditioning = conditioning
+        """Provides methods to save and get conditioning data."""
+        self.models = models
+        """Provides methods to check if a model exists, get a model, and get a model's info."""
+        self.logger = logger
+        """Provides access to the app logger."""
+        self.config = config
+        """Provides access to the app's config."""
+        self.util = util
+        """Provides utility methods."""
+        self.boards = boards
+        """Provides methods to interact with boards."""
+        self._data = context_data
+        """Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
+        self._services = services
+        """Provides access to the full application services. This is an internal API and may change without warning."""
+
+
+def build_invocation_context(
+    services: InvocationServices,
+    context_data: InvocationContextData,
+) -> InvocationContext:
+    """
+    Builds the invocation context for a specific invocation execution.
+
+    :param services: The invocation services to wrap.
+    :param context_data: The invocation context data.
+ """ + + logger = LoggerInterface(services=services, context_data=context_data) + images = ImagesInterface(services=services, context_data=context_data) + tensors = TensorsInterface(services=services, context_data=context_data) + models = ModelsInterface(services=services, context_data=context_data) + config = ConfigInterface(services=services, context_data=context_data) + util = UtilInterface(services=services, context_data=context_data) + conditioning = ConditioningInterface(services=services, context_data=context_data) + boards = BoardsInterface(services=services, context_data=context_data) + + ctx = InvocationContext( + images=images, + logger=logger, + config=config, + tensors=tensors, + models=models, + context_data=context_data, + util=util, + conditioning=conditioning, + services=services, + boards=boards, + ) + + return ctx diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 6079b3f08d7..681886eacd3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_3(app_config=config, logger=logger)) migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_5()) + migrator.register_migration(build_migration_6()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py new file mode 100644 index 00000000000..1f9ac56518c --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -0,0 +1,62 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration6Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._recreate_model_triggers(cursor) + self._delete_ip_adapters(cursor) + + def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: + """ + Adds the timestamp trigger to the model_config table. + + This trigger was inadvertently dropped in earlier migration scripts. + """ + + cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_updated_at + AFTER UPDATE + ON model_config FOR EACH ROW + BEGIN + UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + + def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None: + """ + Delete all the IP adapters. + + The model manager will automatically find and re-add them after the migration + is done. This allows the manager to add the correct image encoder to their + configuration records. + """ + + cursor.execute( + """--sql + DELETE FROM model_config + WHERE type='ip_adapter'; + """ + ) + + +def build_migration_6() -> Migration: + """ + Build the migration from database version 5 to 6. 
+ + This migration does the following: + - Adds the model_config_updated_at trigger if it does not exist + - Delete all ip_adapter models so that the model prober can find and + update with the correct image processor model. + """ + migration_6 = Migration( + from_version=5, + to_version=6, + callback=Migration6Callback(), + ) + + return migration_6 diff --git a/invokeai/app/shared/fields.py b/invokeai/app/shared/fields.py deleted file mode 100644 index 3e841ffbf22..00000000000 --- a/invokeai/app/shared/fields.py +++ /dev/null @@ -1,67 +0,0 @@ -class FieldDescriptions: - denoising_start = "When to start denoising, expressed a percentage of total steps" - denoising_end = "When to stop denoising, expressed a percentage of total steps" - cfg_scale = "Classifier-Free Guidance scale" - cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" - scheduler = "Scheduler to use during inference" - positive_cond = "Positive conditioning tensor" - negative_cond = "Negative conditioning tensor" - noise = "Noise tensor" - clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" - unet = "UNet (scheduler, LoRAs)" - vae = "VAE" - cond = "Conditioning tensor" - controlnet_model = "ControlNet model to load" - vae_model = "VAE model to load" - lora_model = "LoRA model to load" - main_model = "Main model (UNet, VAE, CLIP) to load" - sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" - sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" - onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" - lora_weight = "The weight at which the LoRA is applied to each model" - compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" - raw_prompt = "Raw prompt text (no parsing)" - sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" - skipped_layers = "Number of layers to skip in text encoder" - seed = "Seed for random number generation" - steps = "Number of steps to run" - width = "Width of output (px)" - height = "Height of output (px)" - control = "ControlNet(s) to apply" - ip_adapter = "IP-Adapter to apply" - t2i_adapter = "T2I-Adapter(s) to apply" - denoised_latents = "Denoised latents tensor" - latents = "Latents tensor" - strength = "Strength of denoising (proportional to steps)" - metadata = "Optional metadata to be saved with the image" - metadata_collection = "Collection of Metadata" - metadata_item_polymorphic = "A single metadata item or collection of metadata items" - metadata_item_label = "Label for this metadata item" - metadata_item_value = "The value for this metadata item (may be any type)" - workflow = "Optional workflow to be saved with the image" - interp_mode = "Interpolation mode" - torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" - fp32 = "Whether or not to use full float32 precision" - precision = "Precision to use" - tiled = "Processing using overlapping tiles (reduce memory consumption)" - detect_res = "Pixel resolution for detection" - image_res = "Pixel resolution for output image" - safe_mode = "Whether or not to use safe mode" - scribble_mode = "Whether or not to use scribble mode" - scale_factor = "The factor by which to scale" - blend_alpha = ( - "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." 
- ) - num_1 = "The first number" - num_2 = "The second number" - mask = "The mask to use for the operation" - board = "The board to save the image to" - image = "The image to process" - tile_size = "Tile size" - inclusive_low = "The inclusive low value" - exclusive_high = "The exclusive high value" - decimal_places = "The number of decimal places to round to" - freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." - freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." diff --git a/invokeai/app/shared/models.py b/invokeai/app/shared/models.py index ed68cb287e3..1a11b480cc5 100644 --- a/invokeai/app/shared/models.py +++ b/invokeai/app/shared/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions class FreeUConfig(BaseModel): diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 910b05d8dde..da431929dbe 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -5,7 +5,7 @@ import numpy as np -def get_timestamp(): +def get_timestamp() -> int: return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) @@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: SEED_MAX = np.iinfo(np.uint32).max -def get_random_seed(): +def get_random_seed() -> int: rng = np.random.default_rng(seed=None) return int(rng.integers(0, SEED_MAX)) -def uuid_string(): +def uuid_string() -> str: res = uuid.uuid4() return str(res) -def is_optional(value: typing.Any): +def is_optional(value: typing.Any) -> bool: """Checks if a value is typed as Optional. 
Note that Optional is sugar for Union[x, None].""" return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value) diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index f166206d528..33d00ca3660 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,12 +1,18 @@ +from typing import TYPE_CHECKING + import torch from PIL import Image from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.backend.model_manager.config import BaseModelType -from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL -from ..invocations.baseinvocation import InvocationContext + +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC + from invokeai.app.services.shared.invocation_context import InvocationContextData def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -25,13 +31,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context: InvocationContext, + context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, - node: dict, - source_node_id: str, base_model: BaseModelType, -): - if context.services.queue.is_canceled(context.graph_execution_state_id): + invocation_queue: "InvocationQueueABC", + events: "EventServiceBase", +) -> None: + if invocation_queue.is_canceled(context_data.session_id): raise CanceledException # Some schedulers report not only the noisy latents at the current timestep, @@ -108,13 +114,13 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - context.services.events.emit_generator_progress( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - node=node, - source_node_id=source_node_id, + events.emit_generator_progress( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, + node_id=context_data.invocation.id, + source_node_id=context_data.source_node_id, progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), step=intermediate_state.step, order=intermediate_state.order, diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index ae9a12edbe2..9fe97ee525e 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,5 +1,3 @@ """ Initialization file for invokeai.backend """ -from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401 -from .model_management.models import SilenceWarnings # noqa: F401 diff --git a/invokeai/backend/embeddings/__init__.py b/invokeai/backend/embeddings/__init__.py new file mode 100644 index 00000000000..46ead533c4d --- /dev/null +++ b/invokeai/backend/embeddings/__init__.py @@ -0,0 +1,4 @@ +"""Initialization file for invokeai.backend.embeddings modules.""" + +# from .model_patcher import ModelPatcher +# __all__ = ["ModelPatcher"] diff --git a/invokeai/backend/embeddings/embedding_base.py 
b/invokeai/backend/embeddings/embedding_base.py new file mode 100644 index 00000000000..5e752a29e14 --- /dev/null +++ b/invokeai/backend/embeddings/embedding_base.py @@ -0,0 +1,12 @@ +"""Base class for LoRA and Textual Inversion models. + +The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. + +The use of "Raw" here is a historical artifact, and carried forward in +order to avoid confusion. +""" + + +class EmbeddingModelRaw: + """Base class for LoRA and Textual Inversion models.""" diff --git a/invokeai/backend/embeddings/lora.py b/invokeai/backend/embeddings/lora.py new file mode 100644 index 00000000000..3c7ef074efe --- /dev/null +++ b/invokeai/backend/embeddings/lora.py @@ -0,0 +1,625 @@ +# Copyright (c) 2024 The InvokeAI Development team +"""LoRA model support.""" + +import bisect +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch +from safetensors.torch import load_file +from typing_extensions import Self + +from invokeai.backend.model_manager import BaseModelType + +from .embedding_base import EmbeddingModelRaw + + +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): + # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + if "alpha" in values: + self.alpha = values["alpha"].item() + else: + self.alpha = None + + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: + self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor( + values["bias_indices"], + values["bias_values"], + tuple(values["bias_size"]), + ) + + else: + self.bias = None + + self.rank = None # set in layer implementation + self.layer_key = layer_key + + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError() + + def calc_size(self) -> int: + model_size = 0 + for val in [self.bias]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if self.bias is not None: + self.bias = self.bias.to(device=device, dtype=dtype) + + +# TODO: find and debug lora/locon with bias +class LoRALayer(LoRALayerBase): + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.up = values["lora_up.weight"] + self.down = values["lora_down.weight"] + if "lora_mid.weight" in values: + self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] + else: + self.mid = None + + self.rank = self.down.shape[0] + + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + if self.mid is not None: + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) + weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) + else: + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.up, self.mid, self.down]: + if val is not None: + model_size += val.nelement() * val.element_size() + return 
model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + self.up = self.up.to(device=device, dtype=dtype) + self.down = self.down.to(device=device, dtype=dtype) + + if self.mid is not None: + self.mid = self.mid.to(device=device, dtype=dtype) + + +class LoHALayer(LoRALayerBase): + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]): + super().__init__(layer_key, values) + + self.w1_a = values["hada_w1_a"] + self.w1_b = values["hada_w1_b"] + self.w2_a = values["hada_w2_a"] + self.w2_b = values["hada_w2_b"] + + if "hada_t1" in values: + self.t1: Optional[torch.Tensor] = values["hada_t1"] + else: + self.t1 = None + + if "hada_t2" in values: + self.t2: Optional[torch.Tensor] = values["hada_t2"] + else: + self.t2 = None + + self.rank = self.w1_b.shape[0] + + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + if self.t1 is None: + weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + + else: + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) + weight = rebuild1 * rebuild2 + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + if self.t1 is not None: + self.t1 = self.t1.to(device=device, dtype=dtype) + + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + + +class LoKRLayer(LoRALayerBase): + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + if "lokr_w1" in values: + self.w1: Optional[torch.Tensor] = values["lokr_w1"] + self.w1_a = None + self.w1_b = None + else: + self.w1 = None + self.w1_a = values["lokr_w1_a"] + self.w1_b = values["lokr_w1_b"] + + if "lokr_w2" in values: + self.w2: Optional[torch.Tensor] = values["lokr_w2"] + self.w2_a = None + self.w2_b = None + else: + self.w2 = None + self.w2_a = values["lokr_w2_a"] + self.w2_b = values["lokr_w2_b"] + + if "lokr_t2" in values: + self.t2: Optional[torch.Tensor] = values["lokr_t2"] + else: + self.t2 = None + + if "lokr_w1_b" in values: + self.rank = values["lokr_w1_b"].shape[0] + elif "lokr_w2_b" in values: + self.rank = values["lokr_w2_b"].shape[0] + else: + self.rank = None # unscaled + + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + w1: Optional[torch.Tensor] = self.w1 + if w1 is None: + assert self.w1_a is not 
None + assert self.w1_b is not None + w1 = self.w1_a @ self.w1_b + + w2 = self.w2 + if w2 is None: + if self.t2 is None: + assert self.w2_a is not None + assert self.w2_b is not None + w2 = self.w2_a @ self.w2_b + else: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + assert w1 is not None + assert w2 is not None + weight = torch.kron(w1, w2) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + if self.w1 is not None: + self.w1 = self.w1.to(device=device, dtype=dtype) + else: + assert self.w1_a is not None + assert self.w1_b is not None + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + + if self.w2 is not None: + self.w2 = self.w2.to(device=device, dtype=dtype) + else: + assert self.w2_a is not None + assert self.w2_b is not None + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + + +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = None # unscaled + + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + return self.weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + +class IA3Layer(LoRALayerBase): + # weight: torch.Tensor + # on_input: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.weight = values["weight"] + self.on_input = values["on_input"] + + self.rank = None # unscaled + + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: + weight = self.weight + if not self.on_input: + weight = weight.reshape(-1, 1) + assert orig_weight is not None + return orig_weight * weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + model_size += self.on_input.nelement() * self.on_input.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + self.on_input = self.on_input.to(device=device, dtype=dtype) + + +AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] + + +# TODO: rename all methods used in model logic with Info postfix and 
remove here Raw postfix +class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module): + _name: str + layers: Dict[str, AnyLoRALayer] + + def __init__( + self, + name: str, + layers: Dict[str, AnyLoRALayer], + ): + self._name = name + self.layers = layers + + @property + def name(self) -> str: + return self._name + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + # TODO: try revert if exception? + for _key, layer in self.layers.items(): + layer.to(device=device, dtype=dtype) + + def calc_size(self) -> int: + model_size = 0 + for _, layer in self.layers.items(): + model_size += layer.calc_size() + return model_size + + @classmethod + def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Convert the keys of an SDXL LoRA state_dict to diffusers format. + + The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in + diffusers format, then this function will have no effect. + + This function is adapted from: + https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 + + Args: + state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. + + Raises: + ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. + + Returns: + Dict[str, Tensor]: The diffusers-format state_dict. + """ + converted_count = 0 # The number of Stability AI keys converted to diffusers format. + not_converted_count = 0 # The number of keys that were not converted. + + # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. + # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for + # `input_blocks_4_1_proj_in`. + stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) + stability_unet_keys.sort() + + new_state_dict = {} + for full_key, value in state_dict.items(): + if full_key.startswith("lora_unet_"): + search_key = full_key.replace("lora_unet_", "") + # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. + position = bisect.bisect_right(stability_unet_keys, search_key) + map_key = stability_unet_keys[position - 1] + # Now, check if the map_key *actually* matches the search_key. + if search_key.startswith(map_key): + new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) + new_state_dict[new_key] = value + converted_count += 1 + else: + new_state_dict[full_key] = value + not_converted_count += 1 + elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. + new_state_dict[full_key] = value + continue + else: + raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") + + if converted_count > 0 and not_converted_count > 0: + raise ValueError( + f"The SDXL LoRA could only be partially converted to diffusers format. 
converted={converted_count}," + f" not_converted={not_converted_count}" + ) + + return new_state_dict + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + base_model: Optional[BaseModelType] = None, + ) -> Self: + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + if isinstance(file_path, str): + file_path = Path(file_path) + + model = cls( + name=file_path.stem, + layers={}, + ) + + if file_path.suffix == ".safetensors": + sd = load_file(file_path.absolute().as_posix(), device="cpu") + else: + sd = torch.load(file_path, map_location="cpu") + + state_dict = cls._group_state(sd) + + if base_model == BaseModelType.StableDiffusionXL: + state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) + + for layer_key, values in state_dict.items(): + # lora and locon + if "lora_down.weight" in values: + layer: AnyLoRALayer = LoRALayer(layer_key, values) + + # loha + elif "hada_w1_b" in values: + layer = LoHALayer(layer_key, values) + + # lokr + elif "lokr_w1_b" in values or "lokr_w1" in values: + layer = LoKRLayer(layer_key, values) + + # diff + elif "diff" in values: + layer = FullLayer(layer_key, values) + + # ia3 + elif "weight" in values and "on_input" in values: + layer = IA3Layer(layer_key, values) + + else: + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") + + # lower memory consumption by removing already parsed layer values + state_dict[layer_key].clear() + + layer.to(device=device, dtype=dtype) + model.layers[layer_key] = layer + + return model + + @staticmethod + def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: + state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {} + + for key, value in state_dict.items(): + stem, leaf = key.split(".", 1) + if stem not in state_dict_groupped: + state_dict_groupped[stem] = {} + state_dict_groupped[stem][leaf] = value + + return state_dict_groupped + + +# code from +# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 +def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]: + """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { + sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() +} diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py new file mode 100644 index 00000000000..bee8909c311 --- /dev/null +++ b/invokeai/backend/embeddings/model_patcher.py @@ -0,0 +1,498 @@ +# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team +"""These classes implement model patching with LoRAs and Textual Inversions.""" +from __future__ import annotations + +import pickle +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +import torch +from diffusers import OnnxRuntimeModel, UNet2DConditionModel +from transformers import CLIPTextModel, CLIPTokenizer + +from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager import AnyModel +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + +from .lora import LoRAModelRaw +from .textual_inversion import TextualInversionManager, TextualInversionModelRaw + +""" +loras = [ + (lora_model1, 0.7), + (lora_model2, 0.4), +] +with LoRAHelper.apply_lora_unet(unet, loras): + # unet with applied loras +# unmodified unet + +""" + + +# TODO: rename smth like ModelPatcher and add TI method? 
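+# Usage sketch: the patcher classes below are applied as context managers so that the saved
+# original weights are copied back on exit. In this sketch, `unet`, `lora_model`, `latents`,
+# `timestep`, and `text_embeddings` are hypothetical placeholder objects, not names defined
+# in this module.
+#
+#     loras = [(lora_model, 0.75)]
+#     with ModelPatcher.apply_lora_unet(unet, loras):
+#         # inside the block, the UNet weights include the blended LoRA deltas
+#         noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeddings).sample
+#     # after the block exits, the original UNet weights are restored
+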
+class ModelPatcher: + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." + submodule_name).lstrip(".") + + return (module_key, module) + + @classmethod + @contextmanager + def apply_lora_unet( + cls, + unet: UNet2DConditionModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: + with cls.apply_lora(unet, loras, "lora_unet_"): + yield + + @classmethod + @contextmanager + def apply_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: + with cls.apply_lora(text_encoder, loras, "lora_te_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> None: + with cls.apply_lora(text_encoder, loras, "lora_te1_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder2( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> None: + with cls.apply_lora(text_encoder, loras, "lora_te2_"): + yield + + @classmethod + @contextmanager + def apply_lora( + cls, + model: AnyModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], + prefix: str, + ) -> None: + original_weights = {} + try: + with torch.no_grad(): + for lora, lora_weight in loras: + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + assert isinstance(model, torch.nn.Module) + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + if module_key not in original_weights: + original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. 
More investigation needed. + layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) + layer.to(device=torch.device("cpu")) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + if module.weight.shape != layer_weight.shape: + # TODO: debug on lycoris + assert hasattr(layer_weight, "reshape") + layer_weight = layer_weight.reshape(module.weight.shape) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + module.weight += layer_weight.to(dtype=dtype) + + yield # wait for context manager exit + + finally: + assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() + with torch.no_grad(): + for module_key, weight in original_weights.items(): + model.get_submodule(module_key).weight.copy_(weight) + + @classmethod + @contextmanager + def apply_ti( + cls, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + ti_list: List[Tuple[str, TextualInversionModelRaw]], + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: + init_tokens_count = None + new_tokens_added = None + + # TODO: This is required since Transformers 4.32 see + # https://github.com/huggingface/transformers/pull/25088 + # More information by NVIDIA: + # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + # This value might need to be changed in the future and take the GPUs model into account as there seem + # to be ideal values for different GPUS. This value is temporary! + # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817 + pad_to_multiple_of = 8 + + try: + # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a + # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after + # exiting this `apply_ti(...)` context manager. + # + # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, + # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). + ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) + ti_manager = TextualInversionManager(ti_tokenizer) + init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings + + def _get_trigger(ti_name: str, index: int) -> str: + trigger = ti_name + if index > 0: + trigger += f"-!pad-{i}" + return f"<{trigger}>" + + def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor: + # for SDXL models, select the embedding that matches the text encoder's dimensions + if ti.embedding_2 is not None: + return ( + ti.embedding_2 + if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] + else ti.embedding + ) + else: + return ti.embedding + + # modify tokenizer + new_tokens_added = 0 + for ti_name, ti in ti_list: + ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) + + for i in range(ti_embedding.shape[0]): + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) + + # Modify text_encoder. + # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of + # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some + # time. 
+ with skip_torch_weight_init(): + text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of) + model_embeddings = text_encoder.get_input_embeddings() + + for ti_name, ti in ti_list: + assert isinstance(ti, TextualInversionModelRaw) + ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) + + ti_tokens = [] + for i in range(ti_embedding.shape[0]): + embedding = ti_embedding[i] + trigger = _get_trigger(ti_name, i) + + token_id = ti_tokenizer.convert_tokens_to_ids(trigger) + if token_id == ti_tokenizer.unk_token_id: + raise RuntimeError(f"Unable to find token id for token '{trigger}'") + + if model_embeddings.weight.data[token_id].shape != embedding.shape: + raise ValueError( + f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" + f" {embedding.shape[0]}, but the current model has token dimension" + f" {model_embeddings.weight.data[token_id].shape[0]}." + ) + + model_embeddings.weight.data[token_id] = embedding.to( + device=text_encoder.device, dtype=text_encoder.dtype + ) + ti_tokens.append(token_id) + + if len(ti_tokens) > 1: + ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] + + yield ti_tokenizer, ti_manager + + finally: + if init_tokens_count and new_tokens_added: + text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of) + + @classmethod + @contextmanager + def apply_clip_skip( + cls, + text_encoder: CLIPTextModel, + clip_skip: int, + ) -> None: + skipped_layers = [] + try: + for _i in range(clip_skip): + skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1)) + + yield + + finally: + while len(skipped_layers) > 0: + text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) + + @classmethod + @contextmanager + def apply_freeu( + cls, + unet: UNet2DConditionModel, + freeu_config: Optional[FreeUConfig] = None, + ) -> None: + did_apply_freeu = False + try: + assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? + if freeu_config is not None: + unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) + did_apply_freeu = True + + yield + + finally: + assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? 
+ if did_apply_freeu: + unet.disable_freeu() + + +class ONNXModelPatcher: + @classmethod + @contextmanager + def apply_lora_unet( + cls, + unet: OnnxRuntimeModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: + with cls.apply_lora(unet, loras, "lora_unet_"): + yield + + @classmethod + @contextmanager + def apply_lora_text_encoder( + cls, + text_encoder: OnnxRuntimeModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> None: + with cls.apply_lora(text_encoder, loras, "lora_te_"): + yield + + # based on + # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323 + @classmethod + @contextmanager + def apply_lora( + cls, + model: IAIOnnxRuntimeModel, + loras: List[Tuple[LoRAModelRaw, float]], + prefix: str, + ) -> None: + from .models.base import IAIOnnxRuntimeModel + + if not isinstance(model, IAIOnnxRuntimeModel): + raise Exception("Only IAIOnnxRuntimeModel models supported") + + orig_weights = {} + + try: + blended_loras: Dict[str, torch.Tensor] = {} + + for lora, lora_weight in loras: + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + layer.to(dtype=torch.float32) + layer_key = layer_key.replace(prefix, "") + # TODO: rewrite to pass original tensor weight(required by ia3) + layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight + if layer_key in blended_loras: + blended_loras[layer_key] += layer_weight + else: + blended_loras[layer_key] = layer_weight + + node_names = {} + for node in model.nodes.values(): + node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name + + for layer_key, lora_weight in blended_loras.items(): + conv_key = layer_key + "_Conv" + gemm_key = layer_key + "_Gemm" + matmul_key = layer_key + "_MatMul" + + if conv_key in node_names or gemm_key in node_names: + if conv_key in node_names: + conv_node = model.nodes[node_names[conv_key]] + else: + conv_node = model.nodes[node_names[gemm_key]] + + weight_name = [n for n in conv_node.input if ".weight" in n][0] + orig_weight = model.tensors[weight_name] + + if orig_weight.shape[-2:] == (1, 1): + if lora_weight.shape[-2:] == (1, 1): + new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2)) + else: + new_weight = orig_weight.squeeze((3, 2)) + lora_weight + + new_weight = np.expand_dims(new_weight, (2, 3)) + else: + if orig_weight.shape != lora_weight.shape: + new_weight = orig_weight + lora_weight.reshape(orig_weight.shape) + else: + new_weight = orig_weight + lora_weight + + orig_weights[weight_name] = orig_weight + model.tensors[weight_name] = new_weight.astype(orig_weight.dtype) + + elif matmul_key in node_names: + weight_node = model.nodes[node_names[matmul_key]] + matmul_name = [n for n in weight_node.input if "MatMul" in n][0] + + orig_weight = model.tensors[matmul_name] + new_weight = orig_weight + lora_weight.transpose() + + orig_weights[matmul_name] = orig_weight + model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype) + + else: + # warn? err? 
+ pass + + yield + + finally: + # restore original weights + for name, orig_weight in orig_weights.items(): + model.tensors[name] = orig_weight + + @classmethod + @contextmanager + def apply_ti( + cls, + tokenizer: CLIPTokenizer, + text_encoder: IAIOnnxRuntimeModel, + ti_list: List[Tuple[str, Any]], + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: + from .models.base import IAIOnnxRuntimeModel + + if not isinstance(text_encoder, IAIOnnxRuntimeModel): + raise Exception("Only IAIOnnxRuntimeModel models supported") + + orig_embeddings = None + + try: + # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a + # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after + # exiting this `apply_ti(...)` context manager. + # + # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, + # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). + ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) + ti_manager = TextualInversionManager(ti_tokenizer) + + def _get_trigger(ti_name: str, index: int) -> str: + trigger = ti_name + if index > 0: + trigger += f"-!pad-{i}" + return f"<{trigger}>" + + # modify text_encoder + orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] + + # modify tokenizer + new_tokens_added = 0 + for ti_name, ti in ti_list: + if ti.embedding_2 is not None: + ti_embedding = ( + ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding + ) + else: + ti_embedding = ti.embedding + + for i in range(ti_embedding.shape[0]): + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) + + embeddings = np.concatenate( + (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))), + axis=0, + ) + + for ti_name, _ in ti_list: + ti_tokens = [] + for i in range(ti_embedding.shape[0]): + embedding = ti_embedding[i].detach().numpy() + trigger = _get_trigger(ti_name, i) + + token_id = ti_tokenizer.convert_tokens_to_ids(trigger) + if token_id == ti_tokenizer.unk_token_id: + raise RuntimeError(f"Unable to find token id for token '{trigger}'") + + if embeddings[token_id].shape != embedding.shape: + raise ValueError( + f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" + f" {embedding.shape[0]}, but the current model has token dimension" + f" {embeddings[token_id].shape[0]}." 
+ ) + + embeddings[token_id] = embedding + ti_tokens.append(token_id) + + if len(ti_tokens) > 1: + ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] + + text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype( + orig_embeddings.dtype + ) + + yield ti_tokenizer, ti_manager + + finally: + # restore + if orig_embeddings is not None: + text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings diff --git a/invokeai/backend/embeddings/textual_inversion.py b/invokeai/backend/embeddings/textual_inversion.py new file mode 100644 index 00000000000..389edff039d --- /dev/null +++ b/invokeai/backend/embeddings/textual_inversion.py @@ -0,0 +1,100 @@ +"""Textual Inversion wrapper class.""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from safetensors.torch import load_file +from transformers import CLIPTokenizer +from typing_extensions import Self + +from .embedding_base import EmbeddingModelRaw + + +class TextualInversionModelRaw(EmbeddingModelRaw): + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', + " token will be used.", + ) + + result.embedding = next(iter(state_dict["string_to_param"].values())) + + # v3 (easynegative) + elif "emb_params" in state_dict: + result.embedding = state_dict["emb_params"] + + # v5(sdxl safetensors file) + elif "clip_g" in state_dict and "clip_l" in state_dict: + result.embedding = state_dict["clip_g"] + result.embedding_2 = state_dict["clip_l"] + + # v4(diffusers bin files) + else: + result.embedding = next(iter(state_dict.values())) + + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + + if not isinstance(result.embedding, torch.Tensor): + raise ValueError(f"Invalid embeddings file: {file_path.name}") + + return result + + +# no type hints for BaseTextualInversionManager? 
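+# Usage sketch for the wrapper defined above: `embedding.safetensors` is a hypothetical file
+# path used purely for illustration.
+#
+#     import torch
+#
+#     ti = TextualInversionModelRaw.from_checkpoint("embedding.safetensors", dtype=torch.float16)
+#     print(ti.embedding.shape)  # [n, 768] or [n, 1280], per the class attributes above
+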
+class TextualInversionManager(BaseTextualInversionManager): # type: ignore + pad_tokens: Dict[int, List[int]] + tokenizer: CLIPTokenizer + + def __init__(self, tokenizer: CLIPTokenizer): + self.pad_tokens = {} + self.tokenizer = tokenizer + + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + if len(self.pad_tokens) == 0: + return token_ids + + if token_ids[0] == self.tokenizer.bos_token_id: + raise ValueError("token_ids must not start with bos_token_id") + if token_ids[-1] == self.tokenizer.eos_token_id: + raise ValueError("token_ids must not end with eos_token_id") + + new_token_ids = [] + for token_id in token_ids: + new_token_ids.append(token_id) + if token_id in self.pad_tokens: + new_token_ids.extend(self.pad_tokens[token_id]) + + # Do not exceed the max model input size + # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), + # which first removes and then adds back the start and end tokens. + max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 + if len(new_token_ids) > max_length: + new_token_ids = new_token_ids[0:max_length] + + return new_token_ids diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index b9649925e14..92ddef5ecc3 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -8,8 +8,8 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend import SilenceWarnings from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.silence_warnings import SilenceWarnings config = InvokeAIAppConfig.get_config() diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index e54be527d95..9c386c209ce 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -30,13 +30,14 @@ from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, + ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException from invokeai.backend.util.logging import InvokeAILogger # name of the starter models file -INITIAL_MODELS = "INITIAL_MODELS2.yaml" +INITIAL_MODELS = "INITIAL_MODELS.yaml" def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: @@ -98,11 +99,13 @@ def __init__(self) -> None: super().__init__() self._bars: Dict[str, tqdm] = {} self._last: Dict[str, int] = {} + self._logger = InvokeAILogger.get_logger(__name__) def dispatch(self, event_name: str, payload: Any) -> None: """Dispatch an event by appending it to self.events.""" + data = payload["data"] + source = data["source"] if payload["event"] == "model_install_downloading": - data = payload["data"] dest = data["local_path"] total_bytes = data["total_bytes"] bytes = data["bytes"] @@ -111,6 +114,12 @@ def dispatch(self, event_name: str, payload: Any) -> None: self._last[dest] = 0 self._bars[dest].update(bytes - self._last[dest]) self._last[dest] = bytes + elif payload["event"] == "model_install_completed": + self._logger.info(f"{source}: installed successfully.") + elif payload["event"] == "model_install_error": + self._logger.warning(f"{source}: installation failed with error {data['error']}") + elif payload["event"] == "model_install_cancelled": + self._logger.warning(f"{source}: installation cancelled") class InstallHelper(object): 
@@ -225,11 +234,19 @@ def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: if model_path.exists(): # local file on disk return LocalModelSource(path=model_path.absolute(), inplace=True) - if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id + + # parsing huggingface repo ids + # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" + variants = "|".join([x.lower() for x in ModelRepoVariant.__members__]) + if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): + repo_id = match.group(1) + repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None + subfolder = Path(model_info.subfolder) if model_info.subfolder else None return HFModelSource( - repo_id=model_path_id_or_url, + repo_id=repo_id, access_token=HfFolder.get_token(), - subfolder=model_info.subfolder, + subfolder=subfolder, + variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) @@ -270,9 +287,11 @@ def add_or_delete(self, selections: InstallSelections) -> None: model_name=model_name, ) if len(matches) > 1: - print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.") + print( + f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + ) elif not matches: - print(f"{model}: unknown model") + print(f"{model_to_remove}: unknown model") else: for m in matches: print(f"Deleting {m.type}:{m.name}") diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 3cb7db6c82c..4dfa2b070c0 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -18,31 +18,30 @@ from enum import Enum from pathlib import Path from shutil import get_terminal_size -from typing import Any, get_args, get_type_hints +from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints from urllib import request import npyscreen -import omegaconf import psutil import torch import transformers -import yaml -from diffusers import AutoencoderKL +from diffusers import AutoencoderKL, ModelMixin from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from huggingface_hub import HfFolder from huggingface_hub import login as hf_hub_login -from omegaconf import OmegaConf -from pydantic import ValidationError +from omegaconf import DictConfig, OmegaConf +from pydantic.error_wrappers import ValidationError from tqdm import tqdm from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.install.legacy_arg_parsing import legacy_parser -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained -from invokeai.backend.model_management.model_probe import BaseModelType, ModelType +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.model_install import addModelsForm, process_and_execute +from invokeai.frontend.install.model_install 
import addModelsForm # TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( @@ -61,7 +60,7 @@ transformers.logging.set_verbosity_error() -def get_literal_fields(field) -> list[Any]: +def get_literal_fields(field: str) -> Tuple[Any]: return get_args(get_type_hints(InvokeAIAppConfig).get(field)) @@ -80,8 +79,7 @@ def get_literal_fields(field) -> list[Any]: GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"] GB = 1073741824 # GB in bytes HAS_CUDA = torch.cuda.is_available() -_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0) - +_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0) MAX_VRAM /= GB MAX_RAM = psutil.virtual_memory().total / GB @@ -96,13 +94,15 @@ def get_literal_fields(field) -> list[Any]: class DummyWidgetValue(Enum): + """Dummy widget values.""" + zero = 0 true = True false = False # -------------------------------------------- -def postscript(errors: None): +def postscript(errors: Set[str]) -> None: if not any(errors): message = f""" ** INVOKEAI INSTALLATION SUCCESSFUL ** @@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True): # --------------------------------------------- -def HfLogin(access_token) -> str: +def HfLogin(access_token) -> None: """ Helper for logging in to Huggingface The stdout capture is needed to hide the irrelevant "git credential helper" warning @@ -162,7 +162,7 @@ def HfLogin(access_token) -> str: # ------------------------------------- class ProgressBar: - def __init__(self, model_name="file"): + def __init__(self, model_name: str = "file"): self.pbar = None self.name = model_name @@ -179,6 +179,22 @@ def __call__(self, block_num, block_size, total_size): self.pbar.update(block_size) +# --------------------------------------------- +def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any): + filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731 + logger.addFilter(filter) + try: + model = model_class.from_pretrained( + model_name, + resume_download=True, + **kwargs, + ) + model.save_pretrained(destination, safe_serialization=True) + finally: + logger.removeFilter(filter) + return destination + + # --------------------------------------------- def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"): try: @@ -249,6 +265,7 @@ def download_conversion_models(): # --------------------------------------------- +# TO DO: use the download queue here. def download_realesrgan(): logger.info("Installing ESRGAN Upscaling models...") URLs = [ @@ -288,18 +305,19 @@ def download_lama(): # --------------------------------------------- -def download_support_models(): +def download_support_models() -> None: download_realesrgan() download_lama() download_conversion_models() # ------------------------------------- -def get_root(root: str = None) -> str: +def get_root(root: Optional[str] = None) -> str: if root: return root - elif os.environ.get("INVOKEAI_ROOT"): - return os.environ.get("INVOKEAI_ROOT") + elif root := os.environ.get("INVOKEAI_ROOT"): + assert root is not None + return root else: return str(config.root_path) @@ -455,6 +473,25 @@ def create(self): max_width=110, scroll_exit=True, ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Model disk conversion cache size (GB). 
This is used to cache safetensors files that need to be converted to diffusers..", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 1 + self.disk = self.add_widget_intelligent( + npyscreen.Slider, + value=clip(old_opts.convert_cache, range=(0, 100), step=0.5), + out_of=100, + lowest=0.0, + step=0.5, + relx=8, + scroll_exit=True, + ) + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).", @@ -495,6 +532,14 @@ def create(self): ) else: self.vram = DummyWidgetValue.zero + + self.nextrely += 1 + self.add_widget_intelligent( + npyscreen.FixedText, + value="Location of the database used to store model path and configuration information:", + editable=False, + color="CONTROL", + ) self.nextrely += 1 self.outdir = self.add_widget_intelligent( FileBox, @@ -506,19 +551,21 @@ def create(self): labelColor="GOOD", begin_entry_at=40, max_height=3, + max_width=127, scroll_exit=True, ) self.autoimport_dirs = {} self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent( FileBox, - name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models", - value=str(config.root_path / config.autoimport_dir), + name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models", + value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "", select_dir=True, must_exist=False, use_two_lines=False, labelColor="GOOD", begin_entry_at=32, max_height=3, + max_width=127, scroll_exit=True, ) self.nextrely += 1 @@ -555,6 +602,10 @@ def show_hide_slice_sizes(self, value): self.attention_slice_label.hidden = not show self.attention_slice_size.hidden = not show + def show_hide_model_conf_override(self, value): + self.model_conf_override.hidden = value + self.model_conf_override.display() + def on_ok(self): options = self.marshall_arguments() if self.validate_field_values(options): @@ -584,18 +635,21 @@ def validate_field_values(self, opt: Namespace) -> bool: else: return True - def marshall_arguments(self): + def marshall_arguments(self) -> Namespace: new_opts = Namespace() for attr in [ "ram", "vram", + "convert_cache", "outdir", ]: if hasattr(self, attr): setattr(new_opts, attr, getattr(self, attr).value) for attr in self.autoimport_dirs: + if not self.autoimport_dirs[attr].value: + continue directory = Path(self.autoimport_dirs[attr].value) if directory.is_relative_to(config.root_path): directory = directory.relative_to(config.root_path) @@ -615,13 +669,14 @@ def marshall_arguments(self): class EditOptApplication(npyscreen.NPSAppManaged): - def __init__(self, program_opts: Namespace, invokeai_opts: Namespace): + def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper): super().__init__() self.program_opts = program_opts self.invokeai_opts = invokeai_opts self.user_cancelled = False self.autoload_pending = True - self.install_selections = default_user_selections(program_opts) + self.install_helper = install_helper + self.install_selections = default_user_selections(program_opts, install_helper) def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -640,16 +695,10 @@ def onStart(self): cycle_widgets=False, ) - def new_opts(self): + def new_opts(self) -> Namespace: return self.options.marshall_arguments() -def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace: - 
editApp = EditOptApplication(program_opts, invokeai_opts) - editApp.run() - return editApp.new_opts() - - def default_ramcache() -> float: """Run a heuristic for the default RAM cache based on installed RAM.""" @@ -660,27 +709,18 @@ def default_ramcache() -> float: ) # 2.1 is just large enough for sd 1.5 ;-) -def default_startup_options(init_file: Path) -> Namespace: +def default_startup_options(init_file: Path) -> InvokeAIAppConfig: opts = InvokeAIAppConfig.get_config() - opts.ram = opts.ram or default_ramcache() + opts.ram = default_ramcache() return opts -def default_user_selections(program_opts: Namespace) -> InstallSelections: - try: - installer = ModelInstall(config) - except omegaconf.errors.ConfigKeyError: - logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing") - initialize_rootdir(config.root_path, True) - installer = ModelInstall(config) - - models = installer.all_models() +def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections: + default_model = install_helper.default_model() + assert default_model is not None + default_models = [default_model] if program_opts.default_only else install_helper.recommended_models() return InstallSelections( - install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] - if program_opts.default_only - else [models[x].path or models[x].repo_id for x in installer.recommended_models()] - if program_opts.yes_to_all - else [], + install_models=default_models if program_opts.yes_to_all else [], ) @@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False): path.mkdir(parents=True, exist_ok=True) -def maybe_create_models_yaml(root: Path): - models_yaml = root / "configs" / "models.yaml" - if models_yaml.exists(): - if OmegaConf.load(models_yaml).get("__metadata__"): # up to date - return - else: - logger.info("Creating new models.yaml, original saved as models.yaml.orig") - models_yaml.rename(models_yaml.parent / "models.yaml.orig") - - with open(models_yaml, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - # ------------------------------------- -def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace): +def run_console_ui( + program_opts: Namespace, initfile: Path, install_helper: InstallHelper +) -> Tuple[Optional[Namespace], Optional[InstallSelections]]: invokeai_opts = default_startup_options(initfile) invokeai_opts.root = program_opts.root @@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - # the install-models application spawns a subprocess to install - # models, and will crash unless this is set before running. - import torch - - torch.multiprocessing.set_start_method("spawn") - - editApp = EditOptApplication(program_opts, invokeai_opts) + editApp = EditOptApplication(program_opts, invokeai_opts, install_helper) editApp.run() if editApp.user_cancelled: return (None, None) else: - return (editApp.new_opts, editApp.install_selections) + return (editApp.new_opts(), editApp.install_selections) # ------------------------------------- -def write_opts(opts: Namespace, init_file: Path): +def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None: """ Update the invokeai.yaml file with values from current settings. 
""" @@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path): new_config = InvokeAIAppConfig.get_config() new_config.root = config.root - for key, value in opts.__dict__.items(): + for key, value in opts.model_dump().items(): if hasattr(new_config, key): setattr(new_config, key, value) @@ -779,7 +802,7 @@ def default_output_dir() -> Path: # ------------------------------------- -def write_default_options(program_opts: Namespace, initfile: Path): +def write_default_options(program_opts: Namespace, initfile: Path) -> None: opt = default_startup_options(initfile) write_opts(opt, initfile) @@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path): # the legacy Args object in order to parse # the old init file and write out the new # yaml format. -def migrate_init_file(legacy_format: Path): +def migrate_init_file(legacy_format: Path) -> None: old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) new = InvokeAIAppConfig.get_config() - fields = [ - x - for x, y in InvokeAIAppConfig.model_fields.items() - if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED" - ] - for attr in fields: + for attr in InvokeAIAppConfig.model_fields.keys(): if hasattr(old, attr): try: setattr(new, attr, getattr(old, attr)) @@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path): # ------------------------------------- -def migrate_models(root: Path): +def migrate_models(root: Path) -> None: from invokeai.backend.install.migrate_to_3 import do_migrate do_migrate(root, root) @@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: ): logger.info("** Migrating invokeai.init to invokeai.yaml") migrate_init_file(old_init_file) - config.parse_args(argv=[], conf=OmegaConf.load(new_init_file)) + omegaconf = OmegaConf.load(new_init_file) + assert isinstance(omegaconf, DictConfig) + config.parse_args(argv=[], conf=omegaconf) if old_hub.exists(): migrate_models(config.root_path) @@ -849,7 +869,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--skip-sd-weights", @@ -908,6 +928,7 @@ def main() -> None: if opt.full_precision: invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) + config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) logger = InvokeAILogger().get_logger(config=config) errors = set() @@ -921,14 +942,18 @@ def main() -> None: # run this unconditionally in case new directories need to be added initialize_rootdir(config.root_path, opt.yes_to_all) - models_to_download = default_user_selections(opt) + # this will initialize the models.yaml file if not present + install_helper = InstallHelper(config, logger) + + models_to_download = default_user_selections(opt, install_helper) new_init_file = config.root_path / "invokeai.yaml" if opt.yes_to_all: write_default_options(opt, new_init_file) init_options = Namespace(precision="float32" if opt.full_precision else "float16") + else: - init_options, models_to_download = run_console_ui(opt, new_init_file) + init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper) if init_options: write_opts(init_options, new_init_file) else: @@ -943,10 +968,12 @@ def main() -> None: if opt.skip_sd_weights: logger.warning("Skipping diffusion weights download per user request") + elif 
models_to_download: - process_and_execute(opt, models_to_download) + install_helper.add_or_delete(models_to_download) postscript(errors=errors) + if not opt.yes_to_all: input("Press any key to continue...") except WindowTooSmallException as e: diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 9176bf1f49f..b4706ea99c0 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -8,7 +8,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from invokeai.backend.model_management.models.base import calc_model_size_by_data from .resampler import Resampler @@ -124,6 +123,9 @@ def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): self.attn_weights.to(device=self.device, dtype=self.dtype) def calc_size(self): + # workaround for circular import + from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data + return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) def _init_image_proj_model(self, state_dict): diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 03abf58eb46..d523a7a0c8d 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -3,7 +3,7 @@ Initialization file for invokeai.backend.model_management """ # This import must be first -from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType +from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType from .lora import ModelPatcher, ONNXModelPatcher from .model_cache import ModelCache diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index d72f55794d3..aed5eb60d57 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -102,7 +102,7 @@ def apply_sdxl_lora_text_encoder2( def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw prefix: str, ): original_weights = {} @@ -194,6 +194,8 @@ def _get_trigger(ti_name, index): return f"<{trigger}>" def _get_ti_embedding(model_embeddings, ti): + print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}") + print(f"DEBUG: is it an nn.Module? 
{isinstance(model_embeddings, torch.nn.Module)}") # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -202,6 +204,7 @@ def _get_ti_embedding(model_embeddings, ti): else ti.embedding ) else: + print(f"DEBUG: ti.embedding={type(ti.embedding)}") return ti.embedding # modify tokenizer diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 362d8d3ff55..84d93f15fa8 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -271,7 +271,7 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n @dataclass -class ModelInfo: +class LoadedModelInfo: context: ModelLocker name: str base_model: BaseModelType @@ -450,7 +450,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Given a model named identified in models.yaml, return an ModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -499,7 +499,7 @@ def get_model( model_class=model_class, base_model=base_model, model_type=model_type, - submodel=submodel_type, + submodel_type=submodel_type, ) if model_key not in self.cache_keys: @@ -508,7 +508,7 @@ def get_model( model_hash = "" # TODO: - return ModelInfo( + return LoadedModelInfo( context=model_context, name=model_name, base_model=base_model, diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index da269eba4b7..3b534cb9d14 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -139,7 +139,6 @@ def _convert_controlnet_ckpt_and_cache( cache it to disk, and return Path to converted file. If already on disk then just returns Path. """ - print(f"DEBUG: controlnet config = {model_config}") app_config = InvokeAIAppConfig.get_config() weights = app_config.root_path / model_path output_path = Path(output_path) diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 0f16852c934..98cc5054c73 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,6 +1,7 @@ """Re-export frequently-used symbols from the Model Manager backend.""" from .config import ( + AnyModel, AnyModelConfig, BaseModelType, InvalidModelConfigException, @@ -12,14 +13,17 @@ SchedulerPredictionType, SubModelType, ) +from .load import LoadedModel from .probe import ModelProbe from .search import ModelSearch __all__ = [ + "AnyModel", "AnyModelConfig", "BaseModelType", "ModelRepoVariant", "InvalidModelConfigException", + "LoadedModel", "ModelConfigFactory", "ModelFormat", "ModelProbe", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 964cc19f196..42921f0b32c 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -19,12 +19,22 @@ Validation errors will raise an InvalidModelConfigException error. 
""" +import time from enum import Enum from typing import Literal, Optional, Type, Union +import torch +from diffusers import ModelMixin from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + +from ..embeddings.embedding_base import EmbeddingModelRaw +from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus + +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw] + class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -102,7 +112,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - DEFAULT = "default" # model files without "fp16" or other qualifier + DEFAULT = "" # model files without "fp16" or other qualifier - empty str FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" @@ -113,11 +123,11 @@ class ModelRepoVariant(str, Enum): class ModelConfigBase(BaseModel): """Base class for model configuration information.""" - path: str - name: str - base: BaseModelType - type: ModelType - format: ModelFormat + path: str = Field(description="filesystem path to the model file or directory") + name: str = Field(description="model name") + base: BaseModelType = Field(description="base model") + type: ModelType = Field(description="type of the model") + format: ModelFormat = Field(description="model format") key: str = Field(description="unique key for model", default="") original_hash: Optional[str] = Field( description="original fasthash of model contents", default=None @@ -125,8 +135,9 @@ class ModelConfigBase(BaseModel): current_hash: Optional[str] = Field( description="current fasthash of model contents", default=None ) # if model is converted or otherwise modified, this will hold updated hash - description: Optional[str] = Field(default=None) - source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) + description: Optional[str] = Field(description="human readable description of the model", default=None) + source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, @@ -150,6 +161,7 @@ class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT class LoRAConfig(ModelConfigBase): @@ -199,6 +211,8 @@ class _MainConfig(ModelConfigBase): vae: Optional[str] = Field(default=None) variant: ModelVariantType = ModelVariantType.Normal + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False ztsnr_training: bool = False @@ -212,8 +226,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False class ONNXSD1Config(_MainConfig): @@ -237,10 +249,21 @@ class ONNXSD2Config(_MainConfig): upcast_attention: bool = True +class ONNXSDXLConfig(_MainConfig): + """Model config for ONNX format models 
based on sdxl.""" + + type: Literal[ModelType.ONNX] = ModelType.ONNX + format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + # No yaml config file for ONNX, so these are part of config + base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL + prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction + + class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter + image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] @@ -258,7 +281,7 @@ class T2IConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] -_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")] +_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")] _ControlNetConfig = Annotated[ Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format"), @@ -271,6 +294,7 @@ class T2IConfig(ModelConfigBase): _ONNXConfig, _VaeConfig, _ControlNetConfig, + # ModelConfigBase, LoRAConfig, TextualInversionConfig, IPAdapterConfig, @@ -280,6 +304,7 @@ class T2IConfig(ModelConfigBase): AnyModelConfigValidator = TypeAdapter(AnyModelConfig) + # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown # below. However, it breaks FastAPI when used as the input Body parameter in a route. @@ -308,9 +333,10 @@ class ModelConfigFactory(object): @classmethod def make_config( cls, - model_data: Union[dict, AnyModelConfig], + model_data: Union[Dict[str, Any], AnyModelConfig], key: Optional[str] = None, - dest_class: Optional[Type] = None, + dest_class: Optional[Type[ModelConfigBase]] = None, + timestamp: Optional[float] = None, ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. @@ -321,12 +347,17 @@ def make_config( :param dest_class: The config class to be returned. If not provided, will be selected automatically. """ + model: Optional[ModelConfigBase] = None if isinstance(model_data, ModelConfigBase): model = model_data elif dest_class: - model = dest_class.validate_python(model_data) + model = dest_class.model_validate(model_data) else: - model = AnyModelConfigValidator.validate_python(model_data) + # mypy doesn't typecheck TypeAdapters well? + model = AnyModelConfigValidator.validate_python(model_data) # type: ignore + assert model is not None if key: model.key = key - return model + if timestamp: + model.last_modified = timestamp + return model # type: ignore diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py new file mode 100644 index 00000000000..6f5acd58329 --- /dev/null +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -0,0 +1,1742 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Adapted for use in InvokeAI by Lincoln Stein, July 2023 +# +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from pathlib import Path +from typing import Optional, Union + +import requests +import torch +from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available +from picklescan.scanner import scan_file_path +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import BaseModelType, ModelVariantType +from invokeai.backend.util.logging import InvokeAILogger + +try: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig +except ImportError: + raise ImportError( + "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." + ) + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = InvokeAILogger.get_logger(__name__) +CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert" + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for _i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + 
Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
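+            # (exactly two matching keys means they are the upsampler's conv.weight/conv.bias,
+            # which were assigned just above, rather than a real attention block)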
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": 
"mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config = 
CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load 
uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. 
+ """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, + precision: Optional[torch.dtype] = None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + original_config = ctrlnet_config.copy() + + ctrlnet_config.pop("addition_embed_type") + ctrlnet_config.pop("addition_time_embed_dim") + ctrlnet_config.pop("transformer_layers_per_block") + + if 
use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + original_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet.to(precision) + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path: str, + model_version: BaseModelType, + model_variant: ModelVariantType, + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + load_safety_checker: bool = True, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + text_encoder=None, + tokenizer=None, + scan_needed: bool = True, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. 
+ extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + precision (`torch.dtype`, *optional*, defaults to `None`): + If not provided the precision will be set to the precision of the original file. + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if pipeline_class is None: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if from_safetensors: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path, device="cpu") + else: + if scan_needed: + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise Exception(f"The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") + + precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" + logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") + precision = precision or checkpoint[precision_probing_key].dtype + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + + # model_type = "v1" + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + original_config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(original_config_file) + if original_config["model"]["params"].get("use_ema") is not None: + extract_ema = original_config["model"]["params"]["use_ema"] + + if ( + model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] + and original_config["model"]["params"].get("parameterization") == "v" + ): + prediction_type = "v_prediction" + upcast_attention = True + image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512 + else: + prediction_type = "epsilon" + upcast_attention = False + image_size = 512 + + # Convert the text model. 
+ if ( + model_type is None + and "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ): + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: + num_in_channels = 9 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config.model.params: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config.model.params: + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + ) + + num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + 
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + controlnet=controlnet, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model.to(precision), + unet=unet.to(precision), + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == 
"karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") + + prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "clip-vit-large-patch14" + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + tokenizer = ( + CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + if tokenizer is None + else tokenizer + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + else: + safety_checker = None + feature_extractor = None + + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + if model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLPipeline( + vae=vae.to(precision), + text_encoder=text_encoder.to(precision), + 
tokenizer=tokenizer, + text_encoder_2=text_encoder_2.to(precision), + tokenizer_2=tokenizer_2, + unet=unet.to(precision), + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + else: + tokenizer = None + text_encoder = None + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLImg2ImgPipeline( + vae=vae.to(precision), + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet.to(precision), + scheduler=scheduler, + requires_aesthetics_score=True, + force_zeros_for_empty_prompt=False, + ) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased") + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, + scan_needed: bool = False, +) -> DiffusionPipeline: + from omegaconf import OmegaConf + + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if scan_needed: + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise Exception(f"The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + # use original precision + precision_probing_key = "input_blocks.0.0.bias" + ckpt_precision = checkpoint[precision_probing_key].dtype + logger.debug(f"original controlnet precision = {ckpt_precision}") + precision = precision or ckpt_precision + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config.model.params: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet.to(precision) + + +def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: + vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +def convert_ckpt_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + use_safetensors: bool = True, + **kwargs, +): + """ + Takes all the arguments of download_from_original_stable_diffusion_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained( + dump_path, + safe_serialization=use_safetensors, + ) + + +def convert_controlnet_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + **kwargs, +): + """ + Takes all the arguments of download_controlnet_from_original_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py new file mode 100644 index 00000000000..966a739237a --- /dev/null +++ b/invokeai/backend/model_manager/load/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team +""" +Init file for the model loader. 
+""" +from importlib import import_module +from pathlib import Path +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger + +from .convert_cache.convert_cache_default import ModelConvertCache +from .load_base import AnyModelLoader, LoadedModel +from .model_cache.model_cache_default import ModelCache + +# This registers the subclasses that implement loaders of specific model types +loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] +for module in loaders: + import_module(f"{__package__}.model_loaders.{module}") + +__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] + + +def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: + app_config = app_config or InvokeAIAppConfig.get_config() + logger = InvokeAILogger.get_logger(config=app_config) + return AnyModelLoader( + app_config=app_config, + logger=logger, + ram_cache=ModelCache( + logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size + ), + convert_cache=ModelConvertCache(app_config.models_convert_cache_path), + ) diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py new file mode 100644 index 00000000000..5be56d2d584 --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -0,0 +1,4 @@ +from .convert_cache_base import ModelConvertCacheBase +from .convert_cache_default import ModelConvertCache + +__all__ = ["ModelConvertCacheBase", "ModelConvertCache"] diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py new file mode 100644 index 00000000000..6268c099a5f --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -0,0 +1,27 @@ +""" +Disk-based converted model cache. +""" +from abc import ABC, abstractmethod +from pathlib import Path + + +class ModelConvertCacheBase(ABC): + @property + @abstractmethod + def max_size(self) -> float: + """Return the maximum size of this cache directory.""" + pass + + @abstractmethod + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + pass + + @abstractmethod + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + pass diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py new file mode 100644 index 00000000000..84f4f76299a --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -0,0 +1,72 @@ +""" +Placeholder for convert cache implementation. 
+""" + +import shutil +from pathlib import Path + +from invokeai.backend.util import GIG, directory_size +from invokeai.backend.util.logging import InvokeAILogger + +from .convert_cache_base import ModelConvertCacheBase + + +class ModelConvertCache(ModelConvertCacheBase): + def __init__(self, cache_path: Path, max_size: float = 10.0): + """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" + if not cache_path.exists(): + cache_path.mkdir(parents=True) + self._cache_path = cache_path + self._max_size = max_size + + @property + def max_size(self) -> float: + """Return the maximum size of this cache directory (GB).""" + return self._max_size + + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + return self._cache_path / key + + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + size_needed = directory_size(self._cache_path) + size + max_size = int(self.max_size) * GIG + logger = InvokeAILogger.get_logger() + + if size_needed <= max_size: + return + + logger.debug( + f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming." + ) + + # For this to work, we make the assumption that the directory contains + # a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level. + # This should be true for any diffusers model. + def by_atime(path: Path) -> float: + for config in ["model_index.json", "unet/config.json", "config.json"]: + sentinel = path / config + if sentinel.exists(): + return sentinel.stat().st_atime + + # no sentinel file found! - pick the most recent file in the directory + try: + atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True) + return atimes[0] + except IndexError: + return 0.0 + + # sort by last access time - least accessed files will be at the end + lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) + logger.debug(f"cached models in descending atime order: {lru_models}") + while size_needed > max_size and len(lru_models) > 0: + next_victim = lru_models.pop() + victim_size = directory_size(next_victim) + logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB") + shutil.rmtree(next_victim) + size_needed -= victim_size diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py new file mode 100644 index 00000000000..7649dee762b --- /dev/null +++ b/invokeai/backend/model_manager/load/load_base.py @@ -0,0 +1,195 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +""" +Base class for model loading in InvokeAI. + +Use like this: + + loader = AnyModelLoader(...) 
+ loaded_model = loader.get_model('019ab39adfa1840455') + with loaded_model as model: # context manager moves model into VRAM + # do something with loaded_model +""" + +import hashlib +from abc import ABC, abstractmethod +from dataclasses import dataclass +from logging import Logger +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple, Type + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) +from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.util.logging import InvokeAILogger + + +@dataclass +class LoadedModel: + """Context manager object that mediates transfer from RAM<->VRAM.""" + + config: AnyModelConfig + locker: ModelLockerBase + + def __enter__(self) -> AnyModel: + """Context entry.""" + self.locker.lock() + return self.model + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Context exit.""" + self.locker.unlock() + + @property + def model(self) -> AnyModel: + """Return the model without locking it.""" + return self.locker.model + + +class ModelLoaderBase(ABC): + """Abstract base class for loading models into RAM/VRAM.""" + + @abstractmethod + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + pass + + @abstractmethod + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its configuration. + + Given a model identified in the model configuration backend, + return a LoadedModel object that can be used to retrieve the model. + + :param model_config: Model configuration, as returned by ModelConfigRecordStore + :param submodel_type: a SubModelType enum indicating the portion of + the model to retrieve (e.g. SubModelType.Vae) + """ + pass + + @abstractmethod + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Return size in bytes of the model, calculated before loading.""" + pass + + +# TO DO: Better name?
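+# Illustrative sketch (not part of this file): concrete loaders live in the `model_loaders` subpackage
+# and register themselves with AnyModelLoader so they can be looked up by (base, type, format), roughly:
+#
+#     @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
+#     class VaeDiffusersLoader(ModelLoaderBase):
+#         ...
+#
+# The class name above is hypothetical; see the model_loaders subpackage for the real implementations.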
+class AnyModelLoader: + """This class manages the model loaders and invokes the correct one to load a model of given base and type.""" + + # this tracks the loader subclasses + _registry: Dict[str, Type[ModelLoaderBase]] = {} + _logger: Logger = InvokeAILogger.get_logger() + + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize AnyModelLoader with its dependencies.""" + self._app_config = app_config + self._logger = logger + self._ram_cache = ram_cache + self._convert_cache = convert_cache + + @property + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by the loaders.""" + return self._ram_cache + + @property + def convert_cache(self) -> ModelConvertCacheBase: + """Return the convert cache used by the loaders.""" + return self._convert_cache + + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its configuration. + + :param model_config: model configuration, as known to the config backend + :param submodel_type: a SubModelType enum indicating the portion of + the model to retrieve (e.g. SubModelType.Vae) + """ + implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type) + return implementation( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self._convert_cache, + ).load_model(model_config, submodel_type) + + @staticmethod + def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: + return "-".join([base.value, type.value, format.value]) + + @classmethod + def get_implementation( + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """Get subclass of ModelLoaderBase registered to handle base and type.""" + # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned + conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) + + key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type + key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any + implementation = cls._registry.get(key1) or cls._registry.get(key2) + if not implementation: + raise NotImplementedError( + f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, format={config.format}" + ) + return implementation, conf2, submodel_type + + @classmethod + def _handle_subtype_overrides( + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: + if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: + model_path = Path(config.vae) + config_class = ( + VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig + ) + hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() + new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) + submodel_type = None + else: + new_conf = config + return new_conf, submodel_type + + @classmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], 
Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: + cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") + key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) + cls._registry[key] = subclass + return subclass + + return decorator diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py new file mode 100644 index 00000000000..1dac121a300 --- /dev/null +++ b/invokeai/backend/model_manager/load/load_default.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Default implementation of model loading in InvokeAI.""" + +import sys +from logging import Logger +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +from diffusers import ModelMixin +from diffusers.configuration_utils import ConfigMixin + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + InvalidModelConfigException, + ModelRepoVariant, + SubModelType, +) +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.util.devices import choose_torch_device, torch_dtype + + +class ConfigLoader(ConfigMixin): + """Subclass of ConfigMixin for loading diffusers configuration files.""" + + @classmethod + def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Load a diffusrs ConfigMixin configuration.""" + cls.config_name = kwargs.pop("config_name") + # Diffusers doesn't provide typing info + return super().load_config(*args, **kwargs) # type: ignore + + +# TO DO: The loader is not thread safe! +class ModelLoader(ModelLoaderBase): + """Default implementation of ModelLoaderBase.""" + + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + self._app_config = app_config + self._logger = logger + self._ram_cache = ram_cache + self._convert_cache = convert_cache + self._torch_dtype = torch_dtype(choose_torch_device(), app_config) + + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its configuration. + + Given a model's configuration as returned by the ModelRecordConfigStore service, + return a LoadedModel object that can be used for inference. + + :param model config: Configuration record for this model + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. 
ModelType.Vae) + """ + if model_config.type == "main" and not submodel_type: + raise InvalidModelConfigException("submodel_type is required when loading a main model") + + model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) + + if not model_path.exists(): + raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}") + + model_path = self._convert_if_needed(model_config, model_path, submodel_type) + locker = self._load_if_needed(model_config, model_path, submodel_type) + return LoadedModel(config=model_config, locker=locker) + + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + model_base = self._app_config.models_path + result = (model_base / config.path).resolve(), config, submodel_type + return result + + def _convert_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> Path: + cache_path: Path = self._convert_cache.cache_path(config.key) + + if not self._needs_conversion(config, model_path, cache_path): + return cache_path if cache_path.exists() else model_path + + self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + return self._convert_model(config, model_path, cache_path) + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: + return False + + def _load_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> ModelLockerBase: + # TO DO: This is not thread safe! + try: + return self._ram_cache.get(config.key, submodel_type) + except IndexError: + pass + + model_variant = getattr(config, "repo_variant", None) + self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + + # This is where the model is actually loaded! 
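+        # Descriptive note (added for clarity): skip_torch_weight_init() suppresses torch's default random
+        # parameter initialization while the model classes are instantiated; the weights are immediately
+        # overwritten from the checkpoint, so initializing them first would be wasted work.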
+ with skip_torch_weight_init(): + loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type) + + self._ram_cache.put( + config.key, + submodel_type=submodel_type, + model=loaded_model, + size=calc_model_size_by_data(loaded_model), + ) + + return self._ram_cache.get( + key=config.key, + submodel_type=submodel_type, + stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]), + ) + + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Get the size of the model on disk.""" + return calc_model_size_by_fs( + model_path=model_path, + subfolder=submodel_type.value if submodel_type else None, + variant=config.repo_variant if hasattr(config, "repo_variant") else None, + ) + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + # TO DO: Add exception handling + def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + if submodel_type: + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' + ) from e + else: + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config.get("_class_name", None) + if class_name: + return self._hf_definition_to_type(module="diffusers", class_name=class_name) + if config.get("model_type", None) == "clip_vision_model": + class_name = config.get("architectures")[0] + return self._hf_definition_to_type(module="transformers", class_name=class_name) + if not class_name: + raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e + + # This needs to be implemented in subclasses that handle checkpoints + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + raise NotImplementedError + + # This needs to be implemented in the subclass + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + raise NotImplementedError diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py new file mode 100644 index 00000000000..346f5dc4247 --- /dev/null +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -0,0 +1,100 @@ +import gc +from typing import Optional + +import psutil +import torch +from typing_extensions import Self + +from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 + +GB = 2**30 # 1 GB + + +class MemorySnapshot: + """A snapshot of RAM and VRAM usage. 
All values are in bytes.""" + + def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]): + """Initialize a MemorySnapshot. + + Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. + + Args: + process_ram (int): CPU RAM used by the current process. + vram (Optional[int]): VRAM used by torch. + malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil. + """ + self.process_ram = process_ram + self.vram = vram + self.malloc_info = malloc_info + + @classmethod + def capture(cls, run_garbage_collector: bool = True) -> Self: + """Capture and return a MemorySnapshot. + + Note: This function has significant overhead, particularly if `run_garbage_collector == True`. + + Args: + run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM + usage. Defaults to True. + + Returns: + MemorySnapshot + """ + if run_garbage_collector: + gc.collect() + + # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is + # supported on all platforms. + process_ram = psutil.Process().memory_info().rss + + if torch.cuda.is_available(): + vram = torch.cuda.memory_allocated() + else: + # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have + # time to test it properly. + vram = None + + try: + malloc_info = LibcUtil().mallinfo2() # type: ignore + except (OSError, AttributeError): + # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. + # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) + # TODO: Does `mallinfo` work? + malloc_info = None + + return cls(process_ram, vram, malloc_info) + + +def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: + """Get a pretty string describing the difference between two `MemorySnapshot`s.""" + + def get_msg_line(prefix: str, val1: int, val2: int) -> str: + diff = val2 - val1 + return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" + + msg = "" + + if snapshot_1 is None or snapshot_2 is None: + return msg + + msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) + + if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: + msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd) + + msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks) + + msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks) + + libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd + libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2) + + libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd + libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) + + if snapshot_1.vram is not None and snapshot_2.vram is not None: + msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) + + return "\n" + msg if len(msg) > 0 else msg diff --git 
a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py new file mode 100644 index 00000000000..32c682d0424 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -0,0 +1,6 @@ +"""Init file for ModelCache.""" + +from .model_cache_base import ModelCacheBase, CacheStats # noqa F401 +from .model_cache_default import ModelCache # noqa F401 + +__all__ = ["ModelCacheBase", "ModelCache", "CacheStats"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py new file mode 100644 index 00000000000..4a4a3c7d299 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from logging import Logger +from typing import Dict, Generic, Optional, TypeVar + +import torch + +from invokeai.backend.model_manager.config import AnyModel, SubModelType + + +class ModelLockerBase(ABC): + """Base class for the model locker used by the loader.""" + + @abstractmethod + def lock(self) -> AnyModel: + """Lock the contained model and move it into VRAM.""" + pass + + @abstractmethod + def unlock(self) -> None: + """Unlock the contained model, and remove it from VRAM.""" + pass + + @property + @abstractmethod + def model(self) -> AnyModel: + """Return the model.""" + pass + + +T = TypeVar("T") + + +@dataclass +class CacheRecord(Generic[T]): + """Elements of the cache.""" + + key: str + model: T + size: int + loaded: bool = False + _locks: int = 0 + + def lock(self) -> None: + """Lock this record.""" + self._locks += 1 + + def unlock(self) -> None: + """Unlock this record.""" + self._locks -= 1 + assert self._locks >= 0 + + @property + def locked(self) -> bool: + """Return true if record is locked.""" + return self._locks > 0 + + +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + +class ModelCacheBase(ABC, Generic[T]): + """Virtual base class for RAM model cache.""" + + @property + @abstractmethod + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + pass + + @property + @abstractmethod + def execution_device(self) -> torch.device: + """Return the execution device (e.g.
"cuda" for VRAM).""" + pass + + @property + @abstractmethod + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @property + @abstractmethod + def max_cache_size(self) -> float: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self, size_required: int) -> None: + """Offload from VRAM any models not actively in use.""" + pass + + @abstractmethod + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None: + """Move model into the indicated device.""" + pass + + @property + @abstractmethod + def stats(self) -> CacheStats: + """Return collected CacheStats object.""" + pass + + @stats.setter + @abstractmethod + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collectin cache statistics.""" + pass + + @property + @abstractmethod + def logger(self) -> Logger: + """Return the logger used by the cache.""" + pass + + @abstractmethod + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + pass + + @abstractmethod + def put( + self, + key: str, + model: T, + size: int, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + pass + + @abstractmethod + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. + """ + pass + + @abstractmethod + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + pass + + @abstractmethod + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + pass + + @abstractmethod + def print_cuda_stats(self) -> None: + """Log debugging information on CUDA usage.""" + pass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py new file mode 100644 index 00000000000..02ce1266c75 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -0,0 +1,407 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. + +The cache returns context manager generators designed to load the +model into the GPU within the context, and unload outside the +context. 
Use like this: + + cache = ModelCache(max_cache_size=7.5) + cache.put('runwayml/stable-diffusion-1-5', model=model, size=model_size) + locker = cache.get('runwayml/stable-diffusion-1-5') + SD1 = locker.lock() + do_something_in_GPU(SD1) + locker.unlock() + + +""" + +import gc +import logging +import math +import sys +import time +from contextlib import suppress +from logging import Logger +from typing import Dict, List, Optional + +import torch + +from invokeai.backend.model_manager import AnyModel, SubModelType +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.logging import InvokeAILogger + +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase +from .model_locker import ModelLocker + +if choose_torch_device() == torch.device("mps"): + from torch import mps + +# Maximum size of the cache, in gigs +# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously +DEFAULT_MAX_CACHE_SIZE = 6.0 + +# amount of GPU memory to hold in reserve for use by generations (GB) +DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 + +# actual size of a gig +GIG = 1073741824 + +# Size of a MB in bytes. +MB = 2**20 + + +class ModelCache(ModelCacheBase[AnyModel]): + """Implementation of ModelCacheBase.""" + + def __init__( + self, + max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, + max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, + execution_device: torch.device = torch.device("cuda"), + storage_device: torch.device = torch.device("cpu"), + precision: torch.dtype = torch.float16, + sequential_offload: bool = False, + lazy_offloading: bool = True, + sha_chunksize: int = 16777216, + log_memory_usage: bool = False, + logger: Optional[Logger] = None, + ): + """ + Initialize the model RAM cache. + + :param max_cache_size: Maximum size of the RAM cache [6.0 GB] + :param execution_device: Torch device to load active model into [torch.device('cuda')] + :param storage_device: Torch device to save inactive model in [torch.device('cpu')] + :param precision: Precision for loaded models [torch.float16] + :param lazy_offloading: Keep model in VRAM until another model needs to be loaded + :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially + :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache + operation, and the result will be logged (at debug level). There is a time cost to capturing the memory + snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's + behaviour.
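+
+        Example (a minimal sketch; the defaults assume a CUDA execution device):
+
+            cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=2.0)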
+ """ + # allow lazy offloading only when vram cache enabled + self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0 + self._precision: torch.dtype = precision + self._max_cache_size: float = max_cache_size + self._max_vram_cache_size: float = max_vram_cache_size + self._execution_device: torch.device = execution_device + self._storage_device: torch.device = storage_device + self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) + self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG + # used for stats collection + self._stats: Optional[CacheStats] = None + + self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} + self._cache_stack: List[str] = [] + + @property + def logger(self) -> Logger: + """Return the logger used by the cache.""" + return self._logger + + @property + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + return self._lazy_offloading + + @property + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + return self._storage_device + + @property + def execution_device(self) -> torch.device: + """Return the exection device (e.g. "cuda" for VRAM).""" + return self._execution_device + + @property + def max_cache_size(self) -> float: + """Return the cap on cache size.""" + return self._max_cache_size + + @property + def stats(self) -> Optional[CacheStats]: + """Return collected CacheStats object.""" + return self._stats + + @stats.setter + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collectin cache statistics.""" + self._stats = stats + + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + total = 0 + for cache_record in self._cached_models.values(): + total += cache_record.size + return total + + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + key = self._make_cache_key(key, submodel_type) + return key in self._cached_models + + def put( + self, + key: str, + model: AnyModel, + size: int, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + key = self._make_cache_key(key, submodel_type) + assert key not in self._cached_models + + cache_record = CacheRecord(key, model, size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) + + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. 
+ """ + key = self._make_cache_key(key, submodel_type) + if key in self._cached_models: + if self.stats: + self.stats.hits += 1 + else: + if self.stats: + self.stats.misses += 1 + raise IndexError(f"The model with key {key} is not in the cache.") + + cache_entry = self._cached_models[key] + + # more stats + if self.stats: + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) + + # this moves the entry to the top (right end) of the stack + with suppress(Exception): + self._cache_stack.remove(key) + self._cache_stack.append(key) + return ModelLocker( + cache=self, + cache_entry=cache_entry, + ) + + def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: + if self._log_memory_usage: + return MemorySnapshot.capture() + return None + + def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: + if submodel_type: + return f"{model_key}:{submodel_type.value}" + else: + return model_key + + def offload_unlocked_models(self, size_required: int) -> None: + """Move any unused models from VRAM.""" + reserved = self._max_vram_cache_size * GIG + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") + for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): + if vram_in_use <= reserved: + break + if not cache_entry.loaded: + continue + if not cache_entry.locked: + self.move_model_to_device(cache_entry, self.storage_device) + cache_entry.loaded = False + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug( + f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" + ) + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: + """Move model into the indicated device.""" + # These attributes are not in the base ModelMixin class but in various derived classes. + # Some models don't have these attributes, in which case they run in RAM/CPU. + self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") + if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): + return + + source_device = cache_entry.model.device + + # Note: We compare device types only so that 'cuda' == 'cuda:0'. + # This would need to be revised to support multi-GPU. + if torch.device(source_device).type == torch.device(target_device).type: + return + + start_model_to_time = time.time() + snapshot_before = self._capture_memory_snapshot() + cache_entry.model.to(target_device) + snapshot_after = self._capture_memory_snapshot() + end_model_to_time = time.time() + self.logger.debug( + f"Moved model '{cache_entry.key}' from {source_device} to" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." 
+ f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + if ( + snapshot_before is not None + and snapshot_after is not None + and snapshot_before.vram is not None + and snapshot_after.vram is not None + ): + vram_change = abs(snapshot_before.vram - snapshot_after.vram) + + # If the estimated model size does not match the change in VRAM, log a warning. + if not math.isclose( + vram_change, + cache_entry.size, + rel_tol=0.1, + abs_tol=10 * MB, + ): + self.logger.debug( + f"Moving model '{cache_entry.key}' from {source_device} to" + f" {target_device} caused an unexpected change in VRAM usage. The model's" + " estimated size may be incorrect. Estimated model size:" + f" {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + def print_cuda_stats(self) -> None: + """Log CUDA diagnostics.""" + vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) + ram = "%4.2fG" % (self.cache_size() / GIG) + + in_ram_models = 0 + in_vram_models = 0 + locked_in_vram_models = 0 + for cache_record in self._cached_models.values(): + if hasattr(cache_record.model, "device"): + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 + if cache_record.locked: + locked_in_vram_models += 1 + + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" + ) + + def make_room(self, model_size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + # calculate how much memory this model will require + # multiplier = 2 if self.precision==torch.float32 else 1 + bytes_needed = model_size + maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes + current_size = self.cache_size() + + if current_size + bytes_needed > maximum_size: + self.logger.debug( + f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" + f" {(bytes_needed/GIG):.2f} GB" + ) + + self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") + + pos = 0 + models_cleared = 0 + while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): + model_key = self._cache_stack[pos] + cache_entry = self._cached_models[model_key] + + refs = sys.getrefcount(cache_entry.model) + + # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. 
We are directly + # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: + # https://docs.python.org/3/library/gc.html#gc.get_referrers + + # manualy clear local variable references of just finished function calls + # for some reason python don't want to collect it even by gc.collect() immidiately + if refs > 2: + while True: + cleared = False + for referrer in gc.get_referrers(cache_entry.model): + if type(referrer).__name__ == "frame": + # RuntimeError: cannot clear an executing frame + with suppress(RuntimeError): + referrer.clear() + cleared = True + # break + + # repeat if referrers changes(due to frame clear), else exit loop + if cleared: + gc.collect() + else: + break + + device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None + self.logger.debug( + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," + f" refs: {refs}" + ) + + # Expected refs: + # 1 from cache_entry + # 1 from getrefcount function + # 1 from onnx runtime object + if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): + self.logger.debug( + f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + ) + current_size -= cache_entry.size + models_cleared += 1 + del self._cache_stack[pos] + del self._cached_models[model_key] + del cache_entry + + else: + pos += 1 + + if models_cleared > 0: + # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but + # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost + # is high even if no garbage gets collected.) + # + # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: + # - If models had to be cleared, it's a signal that we are close to our memory limit. + # - If models were cleared, there's a good chance that there's a significant amount of garbage to be + # collected. + # + # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up + # immediately when their reference count hits 0. + gc.collect() + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py new file mode 100644 index 00000000000..7a5fdd4284b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -0,0 +1,59 @@ +""" +Base class and implementation of a class that moves models in and out of VRAM. +""" + +from invokeai.backend.model_manager import AnyModel + +from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase + + +class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]): + """ + Initialize the model locker. 
+ + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> AnyModel: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + self._cache_entry.loaded = True + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.print_cuda_stats() diff --git a/invokeai/backend/model_manager/load/model_loaders/__init__.py b/invokeai/backend/model_manager/load/model_loaders/__init__.py new file mode 100644 index 00000000000..962cba54811 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/__init__.py @@ -0,0 +1,3 @@ +""" +Init file for model_loaders. +""" diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py new file mode 100644 index 00000000000..d446d079336 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for ControlNet model loading in InvokeAI.""" + +from pathlib import Path + +import safetensors +import torch + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader + +from .generic_diffusers import GenericDiffusersLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) +class ControlnetLoader(GenericDiffusersLoader): + """Class to load ControlNet models.""" + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"ControlNet conversion not supported for model type: {config.base}") + else: + assert hasattr(config, "config") + config_file = config.config + + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") + else: + checkpoint = torch.load(model_path, map_location="cpu") + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + convert_controlnet_to_diffusers( + model_path, + output_path, + original_config_file=self._app_config.root_path / config_file, + image_size=512, + scan_needed=True, + from_safetensors=model_path.suffix == ".safetensors", + ) + return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py new file mode 100644 index 00000000000..114e317f3c6 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, Lincoln D.
Stein and the InvokeAI Development Team +"""Class for simple diffusers model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) + +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py new file mode 100644 index 00000000000..27ced41c1e9 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for IP Adapter model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +import torch + +from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) +class IPAdapterInvokeAILoader(ModelLoader): + """Class to load IP Adapter diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in an IP-Adapter model.") + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path / "ip_adapter.bin", + device=torch.device("cpu"), + dtype=self._torch_dtype, + ) + return model diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py new file mode 100644 index 00000000000..d8e5f920e24 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for LoRA model loading in InvokeAI.""" + + +from logging import Logger +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) +class LoraLoader(ModelLoader): + """Class to load LoRA models.""" + + # We cheat a little bit to get access to the model base + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + super().__init__(app_config, logger, ram_cache, convert_cache) + self._model_base: Optional[BaseModelType] = None + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a LoRA model.") + assert self._model_base is not None + model = LoRAModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + base_model=self._model_base, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + self._model_base = ( + config.base + ) # cheating a little - we remember this variable for using in the subsequent call to _load_model() + + model_base_path = self._app_config.models_path + model_path = model_base_path / config.path + + if config.format == ModelFormat.Diffusers: + for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder + path = model_base_path / config.path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + result = model_path.resolve(), config, submodel_type + return result diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py new file mode 100644 index 00000000000..935a6b7c953 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for Onnx model loading in InvokeAI.""" + +# This should work the same as Stable Diffusion pipelines +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) +class OnnyxDiffusersModel(ModelLoader): + """Class to load onnx models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading onnx pipelines.") + load_class = self._get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py new file mode 100644 index 00000000000..23b4e1fccd6 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for StableDiffusion model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional + +from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline + +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + ModelVariantType, + SubModelType, +) +from invokeai.backend.model_manager.config import MainCheckpointConfig +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) +class StableDiffusionDiffusersModel(ModelLoader): + """Class to load main models.""" + + model_base_to_model_type = { + BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", + BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", + BaseModelType.StableDiffusionXL: "SDXL", + BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", + } + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading main pipelines.") + load_class = self._get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + 
torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + assert isinstance(config, MainCheckpointConfig) + variant = config.variant + base = config.base + pipeline_class = ( + StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline + ) + + config_file = config.config + + self._logger.info(f"Converting {model_path} to diffusers format") + convert_ckpt_to_diffusers( + model_path, + output_path, + model_type=self.model_base_to_model_type[base], + model_version=base, + model_variant=variant, + original_config_file=self._app_config.root_path / config_file, + extract_ema=True, + scan_needed=True, + pipeline_class=pipeline_class, + from_safetensors=model_path.suffix == ".safetensors", + precision=self._torch_dtype, + load_safety_checker=False, + ) + return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py new file mode 100644 index 00000000000..6635f6b43fe --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for TI model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder) +class TextualInversionLoader(ModelLoader): + """Class to load TI models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a TI model.") + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + model_path = self._app_config.models_path / config.path + + if config.format == ModelFormat.EmbeddingFolder: + path = model_path / "learned_embeds.bin" + else: + path = model_path + + if not path.exists(): + raise OSError(f"The embedding file at {path} was not found") + + return path, config, submodel_type diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py 
b/invokeai/backend/model_manager/load/model_loaders/vae.py new file mode 100644 index 00000000000..3983ea75950 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path + +import safetensors +import torch +from omegaconf import DictConfig, OmegaConf + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader + +from .generic_diffusers import GenericDiffusersLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +class VaeLoader(GenericDiffusersLoader): + """Class to load VAE models.""" + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: + # TO DO: check whether sdxl VAE models convert. + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"Vae conversion not supported for model type: {config.base}") + else: + config_file = ( + "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + ) + + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") + else: + checkpoint = torch.load(model_path, map_location="cpu") + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) + assert isinstance(ckpt_config, DictConfig) + + vae_model = convert_ldm_vae_to_diffusers( + checkpoint=checkpoint, + vae_config=ckpt_config, + image_size=512, + ) + vae_model.to(self._torch_dtype) # set precision appropriately + vae_model.save_pretrained(output_path, safe_serialization=True) + return output_path diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py new file mode 100644 index 00000000000..c55eee48fa5 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_util.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024 The InvokeAI Development Team +"""Various utility functions needed by the loader and caching system.""" + +import json +from pathlib import Path +from typing import Optional + +import torch +from diffusers import DiffusionPipeline + +from invokeai.backend.model_manager.config import AnyModel +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + + +def calc_model_size_by_data(model: AnyModel) -> int: + """Get size of a model in memory in bytes.""" + if isinstance(model, DiffusionPipeline): + return 
_calc_pipeline_by_data(model) + elif isinstance(model, torch.nn.Module): + return _calc_model_by_data(model) + elif isinstance(model, IAIOnnxRuntimeModel): + return _calc_onnx_model_by_data(model) + else: + return 0 + + +def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int: + res = 0 + assert hasattr(pipeline, "components") + for submodel_key in pipeline.components.keys(): + submodel = getattr(pipeline, submodel_key) + if submodel is not None and isinstance(submodel, torch.nn.Module): + res += _calc_model_by_data(submodel) + return res + + +def _calc_model_by_data(model: torch.nn.Module) -> int: + mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) + mem: int = mem_params + mem_bufs # in bytes + return mem + + +def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int: + tensor_size = model.tensors.size() * 2 # The session doubles this + mem = tensor_size # in bytes + return mem + + +def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int: + """Estimate the size of a model on disk in bytes.""" + if model_path.is_file(): + return model_path.stat().st_size + + if subfolder is not None: + model_path = model_path / subfolder + + # this can happen when, for example, the safety checker is not downloaded. + if not model_path.exists(): + return 0 + + all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()] + + fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name} + bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name} + other_files = set(all_files) - fp16_files - bit8_files + + if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF + files = other_files + elif variant == "fp16": + files = fp16_files + elif variant == "8bit": + files = bit8_files + else: + raise NotImplementedError(f"Unknown variant: {variant}") + + # try read from index if exists + index_postfix = ".index.json" + if variant is not None: + index_postfix = f".index.{variant}.json" + + for file in files: + if not file.name.endswith(index_postfix): + continue + try: + with open(model_path / file, "r") as f: + index_data = json.loads(f.read()) + return int(index_data["metadata"]["total_size"]) + except Exception: + pass + + # calculate files size if there is no index file + formats = [ + (".safetensors",), # safetensors + (".bin",), # torch + (".onnx", ".pb"), # onnx + (".msgpack",), # flax + (".ckpt",), # tf + (".h5",), # tf2 + ] + + for file_format in formats: + model_files = [f for f in files if f.suffix in file_format] + if len(model_files) == 0: + continue + + model_size = 0 + for model_file in model_files: + file_stats = (model_path / model_file).stat() + model_size += file_stats.st_size + return model_size + + return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu diff --git a/invokeai/backend/model_manager/load/optimizations.py b/invokeai/backend/model_manager/load/optimizations.py new file mode 100644 index 00000000000..a46d262175f --- /dev/null +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -0,0 +1,30 @@ +from contextlib import contextmanager + +import torch + + +def _no_op(*args, **kwargs): + pass + + +@contextmanager +def skip_torch_weight_init(): + """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) 
+ to skip weight initialization. + + By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular + distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is + completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager + monkey-patches common torch layers to skip the weight initialization step. + """ + torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] + saved_functions = [m.reset_parameters for m in torch_modules] + + try: + for torch_module in torch_modules: + torch_module.reset_parameters = _no_op + + yield None + finally: + for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): + torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py index 6e41d6f11b2..7991f6a7489 100644 --- a/invokeai/backend/model_manager/metadata/fetch/civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py @@ -32,6 +32,8 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, CivitaiMetadata, @@ -82,10 +84,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: return self.from_civitai_versionid(int(version_id)) raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns") - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given a Civitai model version ID, return a ModelRepoMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum (currently ignored) + May raise an `UnknownMetadataException`. """ return self.from_civitai_versionid(int(id)) diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index 58b65b69477..5d75493b92f 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -18,7 +18,9 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session -from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator +from invokeai.backend.model_manager import ModelRepoVariant + +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata class ModelMetadataFetchBase(ABC): @@ -45,10 +47,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: pass @abstractmethod - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given an ID for a model, return a ModelMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum. + This method will raise a `UnknownMetadataException` in the event that the requested model's metadata is not found at the provided id. 
""" @@ -57,5 +62,5 @@ def from_id(self, id: str) -> AnyModelRepoMetadata: @classmethod def from_json(cls, json: str) -> AnyModelRepoMetadata: """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" - metadata = AnyModelRepoMetadataValidator.validate_json(json) + metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore return metadata diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 5d1eb0cc9e4..6f04e8713b2 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,10 +19,12 @@ import requests from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, HuggingFaceMetadata, @@ -53,12 +55,22 @@ def from_json(cls, json: str) -> HuggingFaceMetadata: metadata = HuggingFaceMetadata.model_validate_json(json) return metadata - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" - try: - model_info = HfApi().model_info(repo_id=id, files_metadata=True) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + # Little loop which tries fetching a revision corresponding to the selected variant. + # If not available, then set variant to None and get the default. + # If this too fails, raise exception. + model_info = None + while not model_info: + try: + model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant) + except RepositoryNotFoundError as excp: + raise UnknownMetadataException(f"'{id}' not found. 
See trace for details.") from excp + except RevisionNotFoundError: + if variant is None: + raise + else: + variant = None _, name = id.split("/") return HuggingFaceMetadata( @@ -70,7 +82,7 @@ def from_id(self, id: str) -> AnyModelRepoMetadata: tags=model_info.tags, files=[ RemoteModelFile( - url=hf_hub_url(id, x.rfilename), + url=hf_hub_url(id, x.rfilename, revision=variant), path=Path(name, x.rfilename), size=x.size, sha256=x.lfs.get("sha256") if x.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5aa883d26d0..5c3afcdc960 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -184,7 +184,6 @@ def download_urls( [x.path for x in self.files], variant, subfolder ) # all files in the model prefix = f"{subfolder}/" if subfolder else "" - # the next step reads model_index.json to determine which subdirectories belong # to the model if Path(f"{prefix}model_index.json") in paths: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index cd048d2fe78..2c2066d7c52 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -7,6 +7,7 @@ import torch from picklescan.scanner import scan_file_path +import invokeai.backend.util.logging as logger from invokeai.backend.model_management.models.base import read_checkpoint_meta from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat from invokeai.backend.model_management.util import lora_token_vector_length @@ -18,6 +19,7 @@ InvalidModelConfigException, ModelConfigFactory, ModelFormat, + ModelRepoVariant, ModelType, ModelVariantType, SchedulerPredictionType, @@ -28,8 +30,12 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", + ModelVariantType.Normal: { + SchedulerPredictionType.Epsilon: "v1-inference.yaml", + SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", + }, ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", + ModelVariantType.Depth: "v2-midas-inference.yaml", }, BaseModelType.StableDiffusion2: { ModelVariantType.Normal: { @@ -72,6 +78,10 @@ def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Get model scheduler prediction type.""" return None + def get_image_encoder_model_id(self) -> Optional[str]: + """Get image encoder (IP adapters only).""" + return None + class ModelProbe(object): PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { @@ -147,6 +157,7 @@ def probe( fields["base"] = fields.get("base") or probe.get_base_type() fields["variant"] = fields.get("variant") or probe.get_variant_type() fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() + fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() fields["name"] = fields.get("name") or cls.get_model_name(model_path) fields["description"] = ( fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" @@ -155,6 +166,9 @@ def probe( fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash + if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): + fields["repo_variant"] = 
fields.get("repo_variant") or probe.get_repo_variant() + # additional fields needed for main and controlnet models if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: fields["config"] = cls._get_checkpoint_config_path( @@ -477,6 +491,21 @@ def get_variant_type(self) -> ModelVariantType: def get_format(self) -> ModelFormat: return ModelFormat("diffusers") + def get_repo_variant(self) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(self.model_path.glob("**/*.safetensors")) + weight_files.extend(list(self.model_path.glob("**/*.bin"))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OPENVINO + if "flax_model" in x.name: + return ModelRepoVariant.FLAX + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.DEFAULT + class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: @@ -567,13 +596,20 @@ def get_base_type(self) -> BaseModelType: return TextualInversionCheckpointProbe(path).get_base_type() -class ONNXFolderProbe(FolderProbeBase): +class ONNXFolderProbe(PipelineFolderProbe): + def get_base_type(self) -> BaseModelType: + # Due to the way the installer is set up, the configuration file for safetensors + # will come along for the ride if both the onnx and safetensors forms + # share the same directory. We take advantage of this here. + if (self.model_path / "unet" / "config.json").exists(): + return super().get_base_type() + else: + logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"') + return BaseModelType.StableDiffusion1 + def get_format(self) -> ModelFormat: return ModelFormat("onnx") - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal @@ -638,6 +674,14 @@ def get_base_type(self) -> BaseModelType: f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." 
) + def get_image_encoder_model_id(self) -> Optional[str]: + encoder_id_path = self.model_path / "image_encoder.txt" + if not encoder_id_path.exists(): + return None + with open(encoder_id_path, "r") as f: + image_encoder_model = f.readline().strip() + return image_encoder_model + class CLIPVisionFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 4cc3caebe47..0ead22b743f 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -22,6 +22,7 @@ def find_main_models(model: Path) -> bool: import os from abc import ABC, abstractmethod +from logging import Logger from pathlib import Path from typing import Callable, Optional, Set, Union @@ -29,7 +30,7 @@ def find_main_models(model: Path) -> bool: from invokeai.backend.util.logging import InvokeAILogger -default_logger = InvokeAILogger.get_logger() +default_logger: Logger = InvokeAILogger.get_logger() class SearchStats(BaseModel): @@ -56,7 +57,7 @@ class ModelSearchBase(ABC, BaseModel): on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221 stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 - logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221 + logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221 # fmt: on class Config: @@ -115,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Set[Path] = Field(default=None) - scanned_dirs: Set[Path] = Field(default=None) - pruned_paths: Set[Path] = Field(default=None) + models_found: Optional[Set[Path]] = Field(default=None) + scanned_dirs: Optional[Set[Path]] = Field(default=None) + pruned_paths: Optional[Set[Path]] = Field(default=None) def search_started(self) -> None: self.models_found = set() @@ -128,13 +129,13 @@ def search_started(self) -> None: def model_found(self, model: Path) -> None: self.stats.models_found += 1 - if not self.on_model_found or self.on_model_found(model): + if self.on_model_found is None or self.on_model_found(model): self.stats.models_filtered += 1 self.models_found.add(model) def search_completed(self) -> None: - if self.on_search_completed: - self.on_search_completed(self._models_found) + if self.on_search_completed is not None: + self.on_search_completed(self.models_found) def search(self, directory: Union[Path, str]) -> Set[Path]: self._directory = Path(directory) diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 69760590440..2fd7a3721ab 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -36,23 +36,37 @@ def filter_files( """ variant = variant or ModelRepoVariant.DEFAULT paths: List[Path] = [] + root = files[0].parts[0] + + # if the subfolder is a single file, then bypass the selection and just return it + if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]: + return [root / subfolder] # Start by filtering on model file extensions, discarding images, docs, etc for 
file in files: if file.name.endswith((".json", ".txt")): paths.append(file) - elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")): + elif file.name.endswith( + ( + "learned_embeds.bin", + "ip_adapter.bin", + "lora_weights.safetensors", + "weights.pb", + "onnx_data", + ) + ): paths.append(file) # BRITTLENESS WARNING!! # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid # downloading random checkpoints that might also be in the repo. However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models - # will adhere to this naming convention, so this is an area of brittleness. + # will adhere to this naming convention, so this is an area to be careful of. elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) # limit search to subfolder if requested if subfolder: + subfolder = root / subfolder paths = [x for x in paths if x.parent == Path(subfolder)] # _filter_by_variant uniquifies the paths and returns a set @@ -64,7 +78,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path result = set() basenames: Dict[Path, Path] = {} for path in files: - if path.suffix == ".onnx": + if path.suffix in [".onnx", ".pb", ".onnx_data"]: if variant == ModelRepoVariant.ONNX: result.add(path) diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py new file mode 100644 index 00000000000..f79fa015692 --- /dev/null +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -0,0 +1,216 @@ +# Copyright (c) 2024 The InvokeAI Development Team +import os +import sys +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import onnx +from onnx import numpy_helper +from onnxruntime import InferenceSession, SessionOptions, get_available_providers + +ONNX_WEIGHTS_NAME = "model.onnx" + + +# NOTE FROM LS: This was copied from Stalker's original implementation. 
+# I have not yet gone through and fixed all the type hints +class IAIOnnxRuntimeModel: + class _tensor_access: + def __init__(self, model): # type: ignore + self.model = model + self.indexes = {} + for idx, obj in enumerate(self.model.proto.graph.initializer): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + value = self.model.proto.graph.initializer[self.indexes[key]] + return numpy_helper.to_array(value) + + def __setitem__(self, key: str, value: np.ndarray): # type: ignore + new_node = numpy_helper.from_array(value) + # set_external_data(new_node, location="in-memory-location") + new_node.name = key + # new_node.ClearField("raw_data") + del self.model.proto.graph.initializer[self.indexes[key]] + self.model.proto.graph.initializer.insert(self.indexes[key], new_node) + # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return self.indexes[key] in self.model.proto.graph.initializer + + def items(self) -> List[Tuple[str, Any]]: # fixme + raise NotImplementedError("tensor.items") + # return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + raise NotImplementedError("tensor.values") + # return [obj for obj in self.raw_proto] + + def size(self) -> int: + bytesSum = 0 + for node in self.model.proto.graph.initializer: + bytesSum += sys.getsizeof(node.raw_data) + return bytesSum + + class _access_helper: + def __init__(self, raw_proto): # type: ignore + self.indexes = {} + self.raw_proto = raw_proto + for idx, obj in enumerate(raw_proto): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + return self.raw_proto[self.indexes[key]] + + def __setitem__(self, key: str, value): # type: ignore + index = self.indexes[key] + del self.raw_proto[index] + self.raw_proto.insert(index, value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return key in self.indexes + + def items(self) -> List[Tuple[str, Any]]: + return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + return list(self.raw_proto) + + def __init__(self, model_path: str, provider: Optional[str]): + self.path = model_path + self.session = None + self.provider = provider + """ + self.data_path = self.path + "_data" + if not os.path.exists(self.data_path): + print(f"Moving model tensors to separate file: {self.data_path}") + tmp_proto = onnx.load(model_path, load_external_data=True) + onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) + del tmp_proto + gc.collect() + + self.proto = onnx.load(model_path, load_external_data=False) + """ + + self.proto = onnx.load(model_path, load_external_data=True) + # self.data = dict() + # for tensor in self.proto.graph.initializer: + # name = tensor.name + + # if tensor.HasField("raw_data"): + # npt = numpy_helper.to_array(tensor) + # orv = OrtValue.ortvalue_from_numpy(npt) + # # self.data[name] = orv + # # set_external_data(tensor, location="in-memory-location") + # tensor.name = name + # # tensor.ClearField("raw_data") + + self.nodes = self._access_helper(self.proto.graph.node) # type: ignore + # self.initializers = self._access_helper(self.proto.graph.initializer) + # print(self.proto.graph.input) + # 
print(self.proto.graph.initializer) + + self.tensors = self._tensor_access(self) # type: ignore + + # TODO: integrate with model manager/cache + def create_session(self, height=None, width=None): + if self.session is None or self.session_width != width or self.session_height != height: + # onnx.save(self.proto, "tmp.onnx") + # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) + # TODO: something to be able to get weights when they have already been moved outside of the model proto + # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) + sess = SessionOptions() + # self._external_data.update(**external_data) + # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) + # sess.enable_profiling = True + + # sess.intra_op_num_threads = 1 + # sess.inter_op_num_threads = 1 + # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL + # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + # sess.enable_cpu_mem_arena = True + # sess.enable_mem_pattern = True + # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code + self.session_height = height + self.session_width = width + if height and width: + sess.add_free_dimension_override_by_name("unet_sample_batch", 2) + sess.add_free_dimension_override_by_name("unet_sample_channels", 4) + sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) + sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) + sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) + sess.add_free_dimension_override_by_name("unet_time_batch", 1) + providers = [] + if self.provider: + providers.append(self.provider) + else: + providers = get_available_providers() + if "TensorrtExecutionProvider" in providers: + providers.remove("TensorrtExecutionProvider") + try: + self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) + except Exception as e: + raise e + # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) + # self.io_binding = self.session.io_binding() + + def release_session(self): + self.session = None + import gc + + gc.collect() + return + + def __call__(self, **kwargs): + if self.session is None: + raise Exception("You should call create_session before running model") + + inputs = {k: np.array(v) for k, v in kwargs.items()} + # output_names = self.session.get_outputs() + # for k in inputs: + # self.io_binding.bind_cpu_input(k, inputs[k]) + # for name in output_names: + # self.io_binding.bind_output(name.name) + # self.session.run_with_iobinding(self.io_binding, None) + # return self.io_binding.copy_outputs_to_cpu() + return self.session.run(None, inputs) + + # compatibility with diffusers load code + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + subfolder: Optional[Union[str, Path]] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["SessionOptions"] = None, + **kwargs: Any, + ) -> Any: # fixme + file_name = file_name or ONNX_WEIGHTS_NAME + + if os.path.isdir(model_id): + model_path = model_id + if subfolder is not None: + model_path = os.path.join(model_path, subfolder) + model_path = os.path.join(model_path, file_name) + + else: + model_path = 
model_id + + # load model from local directory + if not os.path.isfile(model_path): + raise Exception(f"Model not found: {model_path}") + + # TODO: session options + return cls(str(model_path), provider=provider) diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index 212045f81b8..75e6aa0a5de 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -4,3 +4,12 @@ from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401 from .diffusion import InvokeAIDiffuserComponent # noqa: F401 from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401 +from .seamless import set_seamless # noqa: F401 + +__all__ = [ + "PipelineIntermediateState", + "StableDiffusionGeneratorPipeline", + "InvokeAIDiffuserComponent", + "AttentionMapSaver", + "set_seamless", +] diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 3e38f9f78d5..0676555f7a9 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -32,6 +32,11 @@ def to(self, device, dtype=None): return self +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): pooled_embeds: torch.Tensor diff --git a/invokeai/backend/stable_diffusion/schedulers/__init__.py b/invokeai/backend/stable_diffusion/schedulers/__init__.py index a4e9dbf9dad..0b780d3ee27 100644 --- a/invokeai/backend/stable_diffusion/schedulers/__init__.py +++ b/invokeai/backend/stable_diffusion/schedulers/__init__.py @@ -1 +1,3 @@ from .schedulers import SCHEDULER_MAP # noqa: F401 + +__all__ = ["SCHEDULER_MAP"] diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py new file mode 100644 index 00000000000..bfdf9e0c536 --- /dev/null +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import List, Union + +import torch.nn as nn +from diffusers.models import AutoencoderKL, UNet2DConditionModel + + +def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) + return nn.functional.conv2d( + working, + weight, + bias, + self.stride, + nn.modules.utils._pair(0), + self.dilation, + self.groups, + ) + + +@contextmanager +def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + try: + to_restore = [] + + for m_name, m in model.named_modules(): + if isinstance(model, UNet2DConditionModel): + if ".attentions." in m_name: + continue + + if ".resnets." in m_name: + if ".conv2" in m_name: + continue + if ".conv_shortcut" in m_name: + continue + + """ + if isinstance(model, UNet2DConditionModel): + if False and ".upsamplers." in m_name: + continue + + if False and ".downsamplers." in m_name: + continue + + if True and ".resnets." 
in m_name: + if True and ".conv1" in m_name: + if False and "down_blocks" in m_name: + continue + if False and "mid_block" in m_name: + continue + if False and "up_blocks" in m_name: + continue + + if True and ".conv2" in m_name: + continue + + if True and ".conv_shortcut" in m_name: + continue + + if True and ".attentions." in m_name: + continue + + if False and m_name in ["conv_in", "conv_out"]: + continue + """ + + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(module, "asymmetric_padding_mode"): + del module.asymmetric_padding_mode + if hasattr(module, "asymmetric_padding"): + del module.asymmetric_padding diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 3c400fc87ce..2757dadba20 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -3,7 +3,7 @@ import numpy as np -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 87ae1480f54..7b48f0364ea 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -12,6 +12,22 @@ torch_dtype, ) from .logging import InvokeAILogger -from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 +from .util import ( # TO DO: Clean this up; remove the unused symbols + GIG, + Chdir, + ask_user, # noqa + directory_size, + download_with_resume, + instantiate_from_config, # noqa + url_attachment_name, # noqa +) -__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"] +__all__ = [ + "GIG", + "directory_size", + "Chdir", + "download_with_resume", + "InvokeAILogger", + "choose_precision", + "choose_torch_device", +] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index d6d3ad727f7..a83d1045f70 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Union +from typing import Literal, Optional, Union import torch from torch import autocast @@ -29,12 +29,19 @@ def choose_torch_device() -> torch.device: return torch.device(config.device) -def choose_precision(device: torch.device) -> str: - """Returns an appropriate precision for the given torch device""" +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. 
+def choose_precision( + device: torch.device, app_config: Optional[InvokeAIAppConfig] = None +) -> Literal["float32", "float16", "bfloat16"]: + """Return an appropriate precision for the given torch device.""" + app_config = app_config or config if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): - if config.precision == "bfloat16": + if app_config.precision == "float32": + return "float32" + elif app_config.precision == "bfloat16": return "bfloat16" else: return "float16" @@ -43,8 +50,14 @@ def choose_precision(device: torch.device) -> str: return "float32" -def torch_dtype(device: torch.device) -> torch.dtype: - precision = choose_precision(device) +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def torch_dtype( + device: Optional[torch.device] = None, + app_config: Optional[InvokeAIAppConfig] = None, +) -> torch.dtype: + device = device or choose_torch_device() + precision = choose_precision(device, app_config) if precision == "float16": return torch.float16 if precision == "bfloat16": diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py new file mode 100644 index 00000000000..068b605da97 --- /dev/null +++ b/invokeai/backend/util/silence_warnings.py @@ -0,0 +1,28 @@ +"""Context class to silence transformers and diffusers warnings.""" +import warnings +from typing import Any + +from diffusers import logging as diffusers_logging +from transformers import logging as transformers_logging + + +class SilenceWarnings(object): + """Use in context to temporarily turn off warnings from transformers & diffusers modules. + + with SilenceWarnings(): + # do something + """ + + def __init__(self) -> None: + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self) -> None: + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, *args: Any) -> None: + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 09b9de9e984..685603cedc6 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -7,7 +7,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType @@ -34,8 +34,8 @@ def install_and_load_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> ModelInfo: - """Install a model if it is not already installed, then get the ModelInfo for that model. +) -> LoadedModelInfo: + """Install a model if it is not already installed, then get the LoadedModelInfo for that model. This is intended as a utility function for tests. 
@@ -49,9 +49,9 @@ def install_and_load_model( submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...). Returns: - ModelInfo + LoadedModelInfo """ - # If the requested model is already installed, return its ModelInfo. + # If the requested model is already installed, return its LoadedModelInfo. with contextlib.suppress(ModelNotFoundException): return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 13751e27702..ae376b41b25 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -24,6 +24,22 @@ from .devices import torch_dtype +# actual size of a gig +GIG = 1073741824 + + +def directory_size(directory: Path) -> int: + """ + Return the aggregate size of all files in a directory (bytes). + """ + sum = 0 + for root, dirs, files in os.walk(directory): + for f in files: + sum += Path(root, f).stat().st_size + for d in dirs: + sum += Path(root, d).stat().st_size + return sum + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index c230665e3a6..ca2283ab811 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -1,153 +1,157 @@ # This file predefines a few models that the user may want to install. sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 + source: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting + source: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 + source: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting + source: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 + source: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 + source: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-vae-fp16-fix: + description: Version of the SDXL-1.0 VAE that works in half precision mode + source: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion + source: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate_v5: +sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors + source: XpucT/Deliberate 
recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion + source: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 + source: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion + source: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney + source: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA + source: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 + source: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster + source: monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - repo_id: lllyasviel/control_v11p_sd15_canny + source: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint + source: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd + source: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth + source: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae + source: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg + source: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart + source: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime + source: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - repo_id: lllyasviel/control_v11p_sd15_openpose + source: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble + source: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge + source: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle + source: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile + source: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p + source: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 + source: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 + source: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - repo_id: 
TencentARC/t2iadapter_depth_sd15v2 + source: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 + source: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 + source: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 + source: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 + source: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d + description: A textual inversion to use in the negative prompt to reduce bad anatomy +sd-1/lora/FlatColor: + source: https://civitai.com/models/6433/loraflatcolor + recommended: True + description: A LoRA that generates scenery using solid blocks of color sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 + source: https://civitai.com/api/download/models/83390 + description: Generate india ink-like landscapes sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 + source: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 + source: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 + source: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl + source: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder + source: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder + source: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/INITIAL_MODELS2.yaml b/invokeai/configs/INITIAL_MODELS.yaml.OLD similarity index 59% rename from invokeai/configs/INITIAL_MODELS2.yaml rename to invokeai/configs/INITIAL_MODELS.yaml.OLD index ca2283ab811..c230665e3a6 100644 --- a/invokeai/configs/INITIAL_MODELS2.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml.OLD @@ -1,157 +1,153 @@ # This file predefines a few models that the user may want to install. 
sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - source: runwayml/stable-diffusion-v1-5 + repo_id: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - source: runwayml/stable-diffusion-inpainting + repo_id: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - source: stabilityai/stable-diffusion-2-1 + repo_id: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - source: stabilityai/stable-diffusion-2-inpainting + repo_id: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - source: stabilityai/stable-diffusion-xl-base-1.0 + repo_id: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - source: stabilityai/stable-diffusion-xl-refiner-1.0 + repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-vae-fp16-fix: - description: Version of the SDXL-1.0 VAE that works in half precision mode - source: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-1-0-vae-fix: + description: Fine tuned version of the SDXL-1.0 VAE + repo_id: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - source: wavymulder/Analog-Diffusion + repo_id: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate: +sd-1/main/Deliberate_v5: description: Versatile model that produces detailed images up to 768px (4.27 GB) - source: XpucT/Deliberate + path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - source: 0xJustin/Dungeons-and-Diffusion + repo_id: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - source: dreamlike-art/dreamlike-photoreal-2.0 + repo_id: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - source: Envvi/Inkpunk-Diffusion + repo_id: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - source: prompthero/openjourney + repo_id: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - source: coreco/seek.art_MEGA + repo_id: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - source: naclbit/trinart_stable_diffusion_v2 + repo_id: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - source: monster-labs/control_v1p_sd15_qrcode_monster + repo_id: 
monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - source: lllyasviel/control_v11p_sd15_canny + repo_id: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - source: lllyasviel/control_v11p_sd15_inpaint + repo_id: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - source: lllyasviel/control_v11p_sd15_mlsd + repo_id: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - source: lllyasviel/control_v11f1p_sd15_depth + repo_id: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - source: lllyasviel/control_v11p_sd15_normalbae + repo_id: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - source: lllyasviel/control_v11p_sd15_seg + repo_id: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - source: lllyasviel/control_v11p_sd15_lineart + repo_id: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - source: lllyasviel/control_v11p_sd15s2_lineart_anime + repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - source: lllyasviel/control_v11p_sd15_openpose + repo_id: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - source: lllyasviel/control_v11p_sd15_scribble + repo_id: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - source: lllyasviel/control_v11p_sd15_softedge + repo_id: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - source: lllyasviel/control_v11e_sd15_shuffle + repo_id: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - source: lllyasviel/control_v11f1e_sd15_tile + repo_id: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - source: lllyasviel/control_v11e_sd15_ip2p + repo_id: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - source: TencentARC/t2iadapter_canny_sd15v2 + repo_id: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - source: TencentARC/t2iadapter_sketch_sd15v2 + repo_id: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - source: TencentARC/t2iadapter_depth_sd15v2 + repo_id: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - source: TencentARC/t2iadapter_zoedepth_sd15v1 + repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - source: TencentARC/t2i-adapter-canny-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - source: TencentARC/t2i-adapter-lineart-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - source: TencentARC/t2i-adapter-sketch-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True - description: A textual inversion to use in the negative prompt to reduce bad anatomy -sd-1/lora/FlatColor: - source: https://civitai.com/models/6433/loraflatcolor - recommended: True - description: A LoRA that generates scenery using solid blocks of color +sd-1/embedding/ahx-beta-453407d: + repo_id: sd-concepts-library/ahx-beta-453407d sd-1/lora/Ink scenery: - source: https://civitai.com/api/download/models/83390 - description: Generate india ink-like landscapes + path: 
https://civitai.com/api/download/models/83390 sd-1/ip_adapter/ip_adapter_sd15: - source: InvokeAI/ip_adapter_sd15 + repo_id: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - source: InvokeAI/ip_adapter_plus_sd15 + repo_id: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - source: InvokeAI/ip_adapter_plus_face_sd15 + repo_id: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - source: InvokeAI/ip_adapter_sdxl + repo_id: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - source: InvokeAI/ip_adapter_sd_image_encoder + repo_id: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - source: InvokeAI/ip_adapter_sdxl_image_encoder + repo_id: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index e23538ffd66..20b630dfc62 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -6,47 +6,45 @@ """ This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. +It is currently named model_install2.py, but will ultimately replace model_install.py. 
""" import argparse import curses -import logging import sys -import textwrap import traceback +import warnings from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe -from pathlib import Path from shutil import get_terminal_size -from typing import Optional +from typing import Any, Dict, List, Optional, Set import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo +from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, - BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, - select_stable_diffusion_config_file, set_min_terminal_size, ) +warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger() +logger = InvokeAILogger.get_logger("ModelInstallService") +logger.setLevel("WARNING") +# logger.setLevel('DEBUG') # build a table mapping all non-printable characters to None # for stripping control characters @@ -58,44 +56,42 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" + """Replace non-printable characters in a string.""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): + """Main form for interactive TUI.""" + # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp, name, multipage=False, *args, **keywords): + def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? + super().__init__(parentApp=parentApp, name=name, **keywords) - def create(self): + def create(self) -> None: + self.installer = self.parentApp.install_helper.installer + self.model_labels = self._get_model_labels() self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - self.nextrely -= 1 + # npyscreen has no typing hints + self.nextrely -= 1 # type: ignore self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 + self.nextrely += 1 # type: ignore self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -115,9 +111,9 @@ def create(self): ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely + top_of_table = self.nextrely # type: ignore self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely + bottom_of_table = self.nextrely # type: ignore self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -162,15 +158,7 @@ def create(self): self.nextrely = bottom_of_table + 1 - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - self.nextrely += 1 - done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -186,14 +174,8 @@ def create(self): npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - label = "APPLY CHANGES & EXIT" + label = "APPLY CHANGES" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -210,17 +192,16 @@ def create(self): ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets = {} - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels + widgets: Dict[str, npyscreen.widget] = {} - self.installed_models = sorted([x for x in starters if models[x].installed]) + all_models = self.all_models # master dict of all models, indexed by key + model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] + model_labels = [self.model_labels[x] for x in model_list] widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", + name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", editable=False, labelColor="CAUTION", ) @@ -230,23 +211,24 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] + + checked = [ + model_list.index(x) + for x in model_list + if (show_recommended and all_models[x].recommended) or all_models[x].installed + ] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, + values=model_labels, + value=checked, + max_height=len(model_list) + 1, relx=4, scroll_exit=True, ), - models=keys, + models=model_list, ) self.nextrely += 1 @@ -257,14 +239,18 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: str = None, - exclude: set = None, + install_prompt: Optional[str] = None, + exclude: Optional[Set[str]] = None, ) -> dict[str, 
npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets = {} - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] + widgets: Dict[str, npyscreen.widget] = {} + all_models = self.all_models + model_list = sorted( + [x for x in all_models if all_models[x].type == model_type and x not in exclude], + key=lambda x: all_models[x].name or "", + ) model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -300,7 +286,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed + if (show_recommended and all_models[x].recommended) or all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -324,7 +310,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, + max_height=6, scroll_exit=True, editable=True, ) @@ -349,13 +335,13 @@ def add_pipeline_widgets( return widgets - def resize(self): + def resize(self) -> None: super().resize() if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] + if model_list := self.starter_pipelines.get("models"): + s.values = [self.model_labels[x] for x in model_list] - def _toggle_tables(self, value=None): + def _toggle_tables(self, value: List[int]) -> None: selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -385,17 +371,18 @@ def _toggle_tables(self, value=None): self.display() def _get_model_labels(self) -> dict[str, str]: + """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 + result = {} models = self.all_models - label_width = max([len(models[x].name) for x in models]) + label_width = max([len(models[x].name or "") for x in self.starter_models]) description_width = window_width - label_width - checkbox_width - spacing_width - result = {} - for x in models.keys(): - description = models[x].description + for key in self.all_models: + description = models[key].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -403,7 +390,8 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) + result[key] = f"%-{label_width}s %s" % (models[key].name, description) + return result def _get_columns(self) -> int: @@ -413,50 +401,40 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel( + if remove_models: + model_names = [self.all_models[x].name or "" for x in remove_models] + mods = "\n".join(model_names) + is_ok = npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) + assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations + return is_ok else: return True - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return + @property + def all_models(self) -> Dict[str, UnifiedModelInfo]: + # npyscreen doesn't having typing hints + return self.parentApp.install_helper.all_models # type: ignore - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() + @property + def starter_models(self) -> List[str]: + return self.parentApp.install_helper._starter_models # type: ignore - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs={ - "opt": app.program_opts, - "selections": app.install_selections, - "conn_out": child_conn, - }, - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() + @property + def installed_models(self) -> List[str]: + return self.parentApp.install_helper._installed_models # type: ignore - def on_back(self): + def on_back(self) -> None: self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self): + def on_cancel(self) -> None: self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self): + def on_done(self) -> None: self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -464,77 +442,7 @@ def on_done(self): self.parentApp.user_cancelled = False self.editing = False - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. 
- saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - - def marshall_arguments(self): + def marshall_arguments(self) -> None: """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -564,46 +472,24 @@ def marshall_arguments(self): models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) + selections.install_models.extend([all_models[x] for x in models_to_install]) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) - - # NOT NEEDED - DONE IN BACKEND NOW - # # special case for the ipadapter_models. If any of the adapters are - # # chosen, then we add the corresponding encoder(s) to the install list. - # section = self.ipadapter_models - # if section.get("models_selected"): - # selected_adapters = [ - # self.all_models[section["models"][x]].name for x in section.get("models_selected").value - # ] - # encoders = [] - # if any(["sdxl" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sdxl_image_encoder") - # if any(["sd15" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sd_image_encoder") - # for encoder in encoders: - # key = f"any/clip_vision/{encoder}" - # repo_id = f"InvokeAI/{encoder}" - # if key not in self.all_models: - # selections.install_models.append(repo_id) - - -class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): + models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] + selections.install_models.extend(models) + + +class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore + def __init__(self, opt: Namespace, install_helper: InstallHelper): super().__init__() self.program_opts = opt self.user_cancelled = False - # self.autoload_pending = True self.install_selections = InstallSelections() + self.install_helper = install_helper - def onStart(self): + def onStart(self) -> None: npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -613,138 +499,62 @@ def onStart(self): ) -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): + """Print out all models of type model_type.""" + models = installer.record_store.search_by_attr(model_type=model_type) + print(f"Installed models of type `{model_type}`:") + for model in models: + path = (config.models_path / model.path).resolve() + 
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the scheduler prediction type of the checkpoint named {model_path.name}: -[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images -[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models -[3] Accept the best guess; you can fix it in the Web UI later -""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select [3]> ").strip() - if not choice: - return None - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "guess": - return None - else: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace): +def select_and_download_models(opt: Namespace) -> None: + """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) + # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + install_helper = InstallHelper(config, logger) + installer = install_helper.installer + if opt.list_models: - installer.list_models(opt.list_models) + list_models(installer, opt.list_models) + elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) + selections = InstallSelections( + 
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] + ) + install_helper.add_or_delete(selections) + elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) + default_model = install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) + install_helper.add_or_delete(selections) + elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) + selections = InstallSelections(install_models=install_helper.recommended_models()) + install_helper.add_or_delete(selections) # this is where the TUI is called else: - # needed to support the probe() method running under a subprocess - torch.multiprocessing.set_start_method("spawn") - if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - installApp = AddModelApplication(opt) + installApp = AddModelApplication(opt, install_helper) try: installApp.run() - except KeyboardInterrupt as e: - if hasattr(installApp, "main_form"): - if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): - logger.info("Terminating subprocesses") - installApp.main_form.subprocess.terminate() - installApp.main_form.subprocess = None - raise e - process_and_execute(opt, installApp.install_selections) + except KeyboardInterrupt: + print("Aborted...") + sys.exit(-1) + + install_helper.add_or_delete(installApp.install_selections) # ------------------------------------- -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -754,7 +564,7 @@ def main(): parser.add_argument( "--delete", nargs="*", - help="List of names of models to idelete", + help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", ) parser.add_argument( "--full-precision", @@ -781,14 +591,6 @@ def main(): choices=[x.value for x in ModelType], help="list installed models", ) - parser.add_argument( - "--config_file", - "-c", - dest="config_file", - type=str, - default=None, - help="path to configuration file to create", - ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install.py.OLD similarity index 57% rename from invokeai/frontend/install/model_install2.py rename to invokeai/frontend/install/model_install.py.OLD index 6eb480c8d9d..e23538ffd66 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install.py.OLD @@ -6,45 +6,47 @@ """ This is the npyscreen frontend to the model installation application. -It is currently named model_install2.py, but will ultimately replace model_install.py. +The work is actually done in backend code in model_install_backend.py. 
""" import argparse import curses +import logging import sys +import textwrap import traceback -import warnings from argparse import Namespace +from multiprocessing import Process +from multiprocessing.connection import Connection, Pipe +from pathlib import Path from shutil import get_terminal_size -from typing import Any, Dict, List, Optional, Set +from typing import Optional import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallService -from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo -from invokeai.backend.model_manager import ModelType +from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType +from invokeai.backend.model_management import ModelManager, ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, + BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, + select_stable_diffusion_config_file, set_min_terminal_size, ) -warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger("ModelInstallService") -logger.setLevel("WARNING") -# logger.setLevel('DEBUG') +logger = InvokeAILogger.get_logger() # build a table mapping all non-printable characters to None # for stripping control characters @@ -56,42 +58,44 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string.""" + """Replace non-printable characters in a string""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): - """Main form for interactive TUI.""" - # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): + def __init__(self, parentApp, name, multipage=False, *args, **keywords): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, **keywords) + super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? - def create(self) -> None: - self.installer = self.parentApp.install_helper.installer - self.model_labels = self._get_model_labels() + def create(self): self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None + if not config.model_conf_path.exists(): + with open(config.model_conf_path, "w") as file: + print("# InvokeAI model configuration file", file=file) + self.installer = ModelInstall(config) + self.all_models = self.installer.all_models() + self.starter_models = self.installer.starter_models() + self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - # npyscreen has no typing hints - self.nextrely -= 1 # type: ignore + self.nextrely -= 1 self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and
revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 # type: ignore + self.nextrely += 1 self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -111,9 +115,9 @@ def create(self) -> None: ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely # type: ignore + top_of_table = self.nextrely self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely # type: ignore + bottom_of_table = self.nextrely self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -158,7 +162,15 @@ def create(self) -> None: self.nextrely = bottom_of_table + 1 + self.monitor = self.add_widget_intelligent( + BufferBox, + name="Log Messages", + editable=False, + max_height=6, + ) + self.nextrely += 1 + done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -174,8 +186,14 @@ def create(self) -> None: npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position + self.ok_button = self.add_widget_intelligent( + npyscreen.ButtonPress, + name=done_label, + relx=(window_width - len(done_label)) // 2, + when_pressed_function=self.on_execute, + ) - label = "APPLY CHANGES" + label = "APPLY CHANGES & EXIT" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -192,16 +210,17 @@ def create(self) -> None: ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets: Dict[str, npyscreen.widget] = {} + widgets = {} + models = self.all_models + starters = self.starter_models + starter_model_labels = self.model_labels - all_models = self.all_models # master dict of all models, indexed by key - model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] - model_labels = [self.model_labels[x] for x in model_list] + self.installed_models = sorted([x for x in starters if models[x].installed]) widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", + name="Select from a starter set of Stable Diffusion models from HuggingFace.", editable=False, labelColor="CAUTION", ) @@ -211,24 +230,23 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - - checked = [ - model_list.index(x) - for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed - ] + keys = [x for x in models.keys() if x in starters] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=model_labels, - value=checked, - max_height=len(model_list) + 1, + values=[starter_model_labels[x] for x in keys], + value=[ + keys.index(x) + for x in keys + if (show_recommended and models[x].recommended) or (x in self.installed_models) + ], + max_height=len(starters) + 1, relx=4, scroll_exit=True, ), - models=model_list, + models=keys, ) self.nextrely += 1 @@ -239,18 +257,14 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: Optional[str] = None, - exclude: Optional[Set[str]] = None, + install_prompt: str = None, + exclude: 
set = None, ) -> dict[str, npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets: Dict[str, npyscreen.widget] = {} - all_models = self.all_models - model_list = sorted( - [x for x in all_models if all_models[x].type == model_type and x not in exclude], - key=lambda x: all_models[x].name or "", - ) + widgets = {} + model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -286,7 +300,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed + if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -310,7 +324,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=6, + max_height=4, scroll_exit=True, editable=True, ) @@ -335,13 +349,13 @@ def add_pipeline_widgets( return widgets - def resize(self) -> None: + def resize(self): super().resize() if s := self.starter_pipelines.get("models_selected"): - if model_list := self.starter_pipelines.get("models"): - s.values = [self.model_labels[x] for x in model_list] + keys = [x for x in self.all_models.keys() if x in self.starter_models] + s.values = [self.model_labels[x] for x in keys] - def _toggle_tables(self, value: List[int]) -> None: + def _toggle_tables(self, value=None): selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -371,18 +385,17 @@ def _toggle_tables(self, value: List[int]) -> None: self.display() def _get_model_labels(self) -> dict[str, str]: - """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 - result = {} models = self.all_models - label_width = max([len(models[x].name or "") for x in self.starter_models]) + label_width = max([len(models[x].name) for x in models]) description_width = window_width - label_width - checkbox_width - spacing_width - for key in self.all_models: - description = models[key].description + result = {} + for x in models.keys(): + description = models[x].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -390,8 +403,7 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[key] = f"%-{label_width}s %s" % (models[key].name, description) - + result[x] = f"%-{label_width}s %s" % (models[x].name, description) return result def _get_columns(self) -> int: @@ -401,40 +413,50 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if remove_models: - model_names = [self.all_models[x].name or "" for x in remove_models] - mods = "\n".join(model_names) - is_ok = npyscreen.notify_ok_cancel( + if len(remove_models) > 0: + mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) + return npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) - assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations - return is_ok else: return True - @property - def all_models(self) -> Dict[str, UnifiedModelInfo]: - # npyscreen doesn't having typing hints - return self.parentApp.install_helper.all_models # type: ignore + def on_execute(self): + self.marshall_arguments() + app = self.parentApp + if not self.confirm_deletions(app.install_selections): + return - @property - def starter_models(self) -> List[str]: - return self.parentApp.install_helper._starter_models # type: ignore + self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) + self.ok_button.hidden = True + self.display() - @property - def installed_models(self) -> List[str]: - return self.parentApp.install_helper._installed_models # type: ignore + # TO DO: Spawn a worker thread, not a subprocess + parent_conn, child_conn = Pipe() + p = Process( + target=process_and_execute, + kwargs={ + "opt": app.program_opts, + "selections": app.install_selections, + "conn_out": child_conn, + }, + ) + p.start() + child_conn.close() + self.subprocess_connection = parent_conn + self.subprocess = p + app.install_selections = InstallSelections() - def on_back(self) -> None: + def on_back(self): self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self) -> None: + def on_cancel(self): self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self) -> None: + def on_done(self): self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -442,7 +464,77 @@ def on_done(self) -> None: self.parentApp.user_cancelled = False self.editing = False - def marshall_arguments(self) -> None: + ########## This routine monitors the child process that is performing model installation and removal ##### + def while_waiting(self): + """Called during idle periods. Main task is to update the Log Messages box with messages + from the child process that does the actual installation/removal""" + c = self.subprocess_connection + if not c: + return + + monitor_widget = self.monitor.entry_widget + while c.poll(): + try: + data = c.recv_bytes().decode("utf-8") + data.strip("\n") + + # processing child is requesting user input to select the + # right configuration file + if data.startswith("*need v2 config"): + _, model_path, *_ = data.split(":", 2) + self._return_v2_config(model_path) + + # processing child is done + elif data == "*done*": + self._close_subprocess_and_regenerate_form() + break + + # update the log message box + else: + data = make_printable(data) + data = data.replace("[A", "") + monitor_widget.buffer( + textwrap.wrap( + data, + width=monitor_widget.width, + subsequent_indent=" ", + ), + scroll_end=True, + ) + self.display() + except (EOFError, OSError): + self.subprocess_connection = None + + def _return_v2_config(self, model_path: str): + c = self.subprocess_connection + model_name = Path(model_path).name + message = select_stable_diffusion_config_file(model_name=model_name) + c.send_bytes(message.encode("utf-8")) + + def _close_subprocess_and_regenerate_form(self): + app = self.parentApp + self.subprocess_connection.close() + self.subprocess_connection = None + self.monitor.entry_widget.buffer(["** Action Complete **"]) + self.display() + + # rebuild the form, saving and restoring some of the fields that need to be preserved. 
+ saved_messages = self.monitor.entry_widget.values + + app.main_form = app.addForm( + "MAIN", + addModelsForm, + name="Install Stable Diffusion Models", + multipage=self.multipage, + ) + app.switchForm("MAIN") + + app.main_form.monitor.entry_widget.values = saved_messages + app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) + # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir + # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan + + def marshall_arguments(self): """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -472,24 +564,46 @@ def marshall_arguments(self) -> None: models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend([all_models[x] for x in models_to_install]) + selections.install_models.extend( + all_models[x].path or all_models[x].repo_id + for x in models_to_install + if all_models[x].path or all_models[x].repo_id + ) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] - selections.install_models.extend(models) - - -class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, opt: Namespace, install_helper: InstallHelper): + selections.install_models.extend(downloads.value.split()) + + # NOT NEEDED - DONE IN BACKEND NOW + # # special case for the ipadapter_models. If any of the adapters are + # # chosen, then we add the corresponding encoder(s) to the install list. 
+ # section = self.ipadapter_models + # if section.get("models_selected"): + # selected_adapters = [ + # self.all_models[section["models"][x]].name for x in section.get("models_selected").value + # ] + # encoders = [] + # if any(["sdxl" in x for x in selected_adapters]): + # encoders.append("ip_adapter_sdxl_image_encoder") + # if any(["sd15" in x for x in selected_adapters]): + # encoders.append("ip_adapter_sd_image_encoder") + # for encoder in encoders: + # key = f"any/clip_vision/{encoder}" + # repo_id = f"InvokeAI/{encoder}" + # if key not in self.all_models: + # selections.install_models.append(repo_id) + + +class AddModelApplication(npyscreen.NPSAppManaged): + def __init__(self, opt): super().__init__() self.program_opts = opt self.user_cancelled = False + # self.autoload_pending = True self.install_selections = InstallSelections() - self.install_helper = install_helper - def onStart(self) -> None: + def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -499,60 +613,138 @@ def onStart(self) -> None: ) -def list_models(installer: ModelInstallService, model_type: ModelType): - """Print out all models of type model_type.""" - models = installer.record_store.search_by_attr(model_type=model_type) - print(f"Installed models of type `{model_type}`:") - for model in models: - path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:14}{path}") +class StderrToMessage: + def __init__(self, connection: Connection): + self.connection = connection + + def write(self, data: str): + self.connection.send_bytes(data.encode("utf-8")) + + def flush(self): + pass # -------------------------------------------------------- -def select_and_download_models(opt: Namespace) -> None: - """Prompt user for install/delete selections and execute.""" - precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal - config.precision = precision # type: ignore - install_helper = InstallHelper(config, logger) - installer = install_helper.installer +def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: + if tui_conn: + logger.debug("Waiting for user response...") + return _ask_user_for_pt_tui(model_path, tui_conn) + else: + return _ask_user_for_pt_cmdline(model_path) - if opt.list_models: - list_models(installer, opt.list_models) - elif opt.add or opt.delete: - selections = InstallSelections( - install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] - ) - install_helper.add_or_delete(selections) +def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: + choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] + print( + f""" +Please select the scheduler prediction type of the checkpoint named {model_path.name}: +[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images +[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models +[3] Accept the best guess; you can fix it in the Web UI later +""" + ) + choice = None + ok = False + while not ok: + try: + choice = input("select [3]> ").strip() + if not choice: + return None + choice = choices[int(choice) - 1] + ok = True + except (ValueError, IndexError): + print(f"{choice} is not a valid choice") + except EOFError: + return + return choice + + 
+def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: + tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) + # note that we don't do any status checking here + response = tui_conn.recv_bytes().decode("utf-8") + if response is None: + return None + elif response == "epsilon": + return SchedulerPredictionType.epsilon + elif response == "v": + return SchedulerPredictionType.VPrediction + elif response == "guess": + return None + else: + return None - elif opt.default_only: - selections = InstallSelections(install_models=[install_helper.default_model()]) - install_helper.add_or_delete(selections) +# -------------------------------------------------------- +def process_and_execute( + opt: Namespace, + selections: InstallSelections, + conn_out: Connection = None, +): + # need to reinitialize config in subprocess + config = InvokeAIAppConfig.get_config() + args = ["--root", opt.root] if opt.root else [] + config.parse_args(args) + + # set up so that stderr is sent to conn_out + if conn_out: + translator = StderrToMessage(conn_out) + sys.stderr = translator + sys.stdout = translator + logger = InvokeAILogger.get_logger() + logger.handlers.clear() + logger.addHandler(logging.StreamHandler(translator)) + + installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) + installer.install(selections) + + if conn_out: + conn_out.send_bytes("*done*".encode("utf-8")) + conn_out.close() + + +# -------------------------------------------------------- +def select_and_download_models(opt: Namespace): + precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) + config.precision = precision + installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + if opt.list_models: + installer.list_models(opt.list_models) + elif opt.add or opt.delete: + selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) + installer.install(selections) + elif opt.default_only: + selections = InstallSelections(install_models=installer.default_model()) + installer.install(selections) elif opt.yes_to_all: - selections = InstallSelections(install_models=install_helper.recommended_models()) - install_helper.add_or_delete(selections) + selections = InstallSelections(install_models=installer.recommended_models()) + installer.install(selections) # this is where the TUI is called else: + # needed to support the probe() method running under a subprocess + torch.multiprocessing.set_start_method("spawn") + if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." 
) - installApp = AddModelApplication(opt, install_helper) + installApp = AddModelApplication(opt) try: installApp.run() - except KeyboardInterrupt: - print("Aborted...") - sys.exit(-1) - - install_helper.add_or_delete(installApp.install_selections) + except KeyboardInterrupt as e: + if hasattr(installApp, "main_form"): + if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): + logger.info("Terminating subprocesses") + installApp.main_form.subprocess.terminate() + installApp.main_form.subprocess = None + raise e + process_and_execute(opt, installApp.install_selections) # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -562,7 +754,7 @@ def main() -> None: parser.add_argument( "--delete", nargs="*", - help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", + help="List of names of models to idelete", ) parser.add_argument( "--full-precision", @@ -589,6 +781,14 @@ def main() -> None: choices=[x.value for x in ModelType], help="list installed models", ) + parser.add_argument( + "--config_file", + "-c", + dest="config_file", + type=str, + default=None, + help="path to configuration file to create", + ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 5905ae29dab..4dbc6349a0b 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -267,6 +267,17 @@ def h_select(self, ch): self.on_changed(self.value) +class CheckboxWithChanged(npyscreen.Checkbox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.on_changed = None + + def whenToggled(self): + super().whenToggled() + if self.on_changed: + self.on_changed(self.value) + + class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged): """Row of radio buttons. 
Spacebar to select.""" diff --git a/invokeai/frontend/merge/merge_diffusers2.py b/invokeai/frontend/merge/merge_diffusers.py.OLD similarity index 100% rename from invokeai/frontend/merge/merge_diffusers2.py rename to invokeai/frontend/merge/merge_diffusers.py.OLD diff --git a/invokeai/frontend/web/.gitignore b/invokeai/frontend/web/.gitignore index 8e7ebc76a1f..3e8a372bc77 100644 --- a/invokeai/frontend/web/.gitignore +++ b/invokeai/frontend/web/.gitignore @@ -41,3 +41,6 @@ stats.html # Yalc .yalc yalc.lock + +# vitest +tsconfig.vitest-temp.json \ No newline at end of file diff --git a/invokeai/frontend/web/config/common.mts b/invokeai/frontend/web/config/common.mts deleted file mode 100644 index fd559cabd1e..00000000000 --- a/invokeai/frontend/web/config/common.mts +++ /dev/null @@ -1,12 +0,0 @@ -import react from '@vitejs/plugin-react-swc'; -import { visualizer } from 'rollup-plugin-visualizer'; -import type { PluginOption, UserConfig } from 'vite'; -import eslint from 'vite-plugin-eslint'; -import tsconfigPaths from 'vite-tsconfig-paths'; - -export const commonPlugins: UserConfig['plugins'] = [ - react(), - eslint(), - tsconfigPaths(), - visualizer() as unknown as PluginOption, -]; diff --git a/invokeai/frontend/web/config/vite.app.config.mts b/invokeai/frontend/web/config/vite.app.config.mts deleted file mode 100644 index 9683ed26a48..00000000000 --- a/invokeai/frontend/web/config/vite.app.config.mts +++ /dev/null @@ -1,33 +0,0 @@ -import type { UserConfig } from 'vite'; - -import { commonPlugins } from './common.mjs'; - -export const appConfig: UserConfig = { - base: './', - plugins: [...commonPlugins], - build: { - chunkSizeWarningLimit: 1500, - }, - server: { - // Proxy HTTP requests to the flask server - proxy: { - // Proxy socket.io to the nodes socketio server - '/ws/socket.io': { - target: 'ws://127.0.0.1:9090', - ws: true, - }, - // Proxy openapi schema definiton - '/openapi.json': { - target: 'http://127.0.0.1:9090/openapi.json', - rewrite: (path) => path.replace(/^\/openapi.json/, ''), - changeOrigin: true, - }, - // proxy nodes api - '/api/v1': { - target: 'http://127.0.0.1:9090/api/v1', - rewrite: (path) => path.replace(/^\/api\/v1/, ''), - changeOrigin: true, - }, - }, - }, -}; diff --git a/invokeai/frontend/web/config/vite.package.config.mts b/invokeai/frontend/web/config/vite.package.config.mts deleted file mode 100644 index 3c05d52e005..00000000000 --- a/invokeai/frontend/web/config/vite.package.config.mts +++ /dev/null @@ -1,46 +0,0 @@ -import path from 'path'; -import type { UserConfig } from 'vite'; -import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; -import dts from 'vite-plugin-dts'; - -import { commonPlugins } from './common.mjs'; - -export const packageConfig: UserConfig = { - base: './', - plugins: [ - ...commonPlugins, - dts({ - insertTypesEntry: true, - }), - cssInjectedByJsPlugin(), - ], - build: { - cssCodeSplit: true, - lib: { - entry: path.resolve(__dirname, '../src/index.ts'), - name: 'InvokeAIUI', - fileName: (format) => `invoke-ai-ui.${format}.js`, - }, - rollupOptions: { - external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'], - output: { - globals: { - react: 'React', - 'react-dom': 'ReactDOM', - '@emotion/react': 'EmotionReact', - '@invoke-ai/ui-library': 'UiLibrary', - }, - }, - }, - }, - resolve: { - alias: { - app: path.resolve(__dirname, '../src/app'), - assets: path.resolve(__dirname, '../src/assets'), - common: path.resolve(__dirname, '../src/common'), - features: 
path.resolve(__dirname, '../src/features'), - services: path.resolve(__dirname, '../src/services'), - theme: path.resolve(__dirname, '../src/theme'), - }, - }, -}; diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index cd95183c7a4..cea13350d26 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -33,7 +33,9 @@ "preinstall": "npx only-allow pnpm", "storybook": "storybook dev -p 6006", "build-storybook": "storybook build", - "unimported": "npx unimported" + "unimported": "npx unimported", + "test": "vitest", + "test:no-watch": "vitest --no-watch" }, "madge": { "excludeRegExp": [ @@ -152,12 +154,14 @@ "rollup-plugin-visualizer": "^5.12.0", "storybook": "^7.6.10", "ts-toolbelt": "^9.6.0", + "tsafe": "^1.6.6", "typescript": "^5.3.3", "vite": "^5.0.12", "vite-plugin-css-injected-by-js": "^3.3.1", "vite-plugin-dts": "^3.7.1", "vite-plugin-eslint": "^1.8.1", - "vite-tsconfig-paths": "^4.3.1" + "vite-tsconfig-paths": "^4.3.1", + "vitest": "^1.2.2" }, "pnpm": { "patchedDependencies": { diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 1d9083d1b44..0ec2e47a0cd 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -215,7 +215,7 @@ devDependencies: version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12) '@storybook/test': specifier: ^7.6.10 - version: 7.6.10 + version: 7.6.10(vitest@1.2.2) '@storybook/theming': specifier: ^7.6.10 version: 7.6.10(react-dom@18.2.0)(react@18.2.0) @@ -300,6 +300,9 @@ devDependencies: ts-toolbelt: specifier: ^9.6.0 version: 9.6.0 + tsafe: + specifier: ^1.6.6 + version: 1.6.6 typescript: specifier: ^5.3.3 version: 5.3.3 @@ -318,6 +321,9 @@ devDependencies: vite-tsconfig-paths: specifier: ^4.3.1 version: 4.3.1(typescript@5.3.3)(vite@5.0.12) + vitest: + specifier: ^1.2.2 + version: 1.2.2(@types/node@20.11.5) packages: @@ -5464,7 +5470,7 @@ packages: - supports-color dev: true - /@storybook/test@7.6.10: + /@storybook/test@7.6.10(vitest@1.2.2): resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==} dependencies: '@storybook/client-logger': 7.6.10 @@ -5472,7 +5478,7 @@ packages: '@storybook/instrumenter': 7.6.10 '@storybook/preview-api': 7.6.10 '@testing-library/dom': 9.3.4 - '@testing-library/jest-dom': 6.2.0 + '@testing-library/jest-dom': 6.2.0(vitest@1.2.2) '@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4) '@types/chai': 4.3.11 '@vitest/expect': 0.34.7 @@ -5652,7 +5658,7 @@ packages: pretty-format: 27.5.1 dev: true - /@testing-library/jest-dom@6.2.0: + /@testing-library/jest-dom@6.2.0(vitest@1.2.2): resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==} engines: {node: '>=14', npm: '>=6', yarn: '>=1'} peerDependencies: @@ -5678,6 +5684,7 @@ packages: dom-accessibility-api: 0.6.3 lodash: 4.17.21 redent: 3.0.0 + vitest: 1.2.2(@types/node@20.11.5) dev: true /@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4): @@ -6490,12 +6497,42 @@ packages: chai: 4.4.1 dev: true + /@vitest/expect@1.2.2: + resolution: {integrity: sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==} + dependencies: + '@vitest/spy': 1.2.2 + '@vitest/utils': 1.2.2 + chai: 4.4.1 + dev: true + + /@vitest/runner@1.2.2: + resolution: {integrity: 
sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==} + dependencies: + '@vitest/utils': 1.2.2 + p-limit: 5.0.0 + pathe: 1.1.2 + dev: true + + /@vitest/snapshot@1.2.2: + resolution: {integrity: sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==} + dependencies: + magic-string: 0.30.5 + pathe: 1.1.2 + pretty-format: 29.7.0 + dev: true + /@vitest/spy@0.34.7: resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==} dependencies: tinyspy: 2.2.0 dev: true + /@vitest/spy@1.2.2: + resolution: {integrity: sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==} + dependencies: + tinyspy: 2.2.0 + dev: true + /@vitest/utils@0.34.7: resolution: {integrity: sha512-ziAavQLpCYS9sLOorGrFFKmy2gnfiNU0ZJ15TsMz/K92NAPS/rp9K4z6AJQQk5Y8adCy4Iwpxy7pQumQ/psnRg==} dependencies: @@ -6504,6 +6541,15 @@ packages: pretty-format: 29.7.0 dev: true + /@vitest/utils@1.2.2: + resolution: {integrity: sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==} + dependencies: + diff-sequences: 29.6.3 + estree-walker: 3.0.3 + loupe: 2.3.7 + pretty-format: 29.7.0 + dev: true + /@volar/language-core@1.11.1: resolution: {integrity: sha512-dOcNn3i9GgZAcJt43wuaEykSluAuOkQgzni1cuxLxTV0nJKanQztp7FxyswdRILaKH+P2XZMPRp2S4MV/pElCw==} dependencies: @@ -7184,6 +7230,11 @@ packages: engines: {node: '>=0.4.0'} dev: true + /acorn-walk@8.3.2: + resolution: {integrity: sha512-cjkyv4OtNCIeqhHrfS81QWXoCBPExR/J62oyEqepVw8WaQeSqpW2uhuLPh1m9eWhDuOo/jUXVTlifvesOWp/4A==} + engines: {node: '>=0.4.0'} + dev: true + /acorn@7.4.1: resolution: {integrity: sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==} engines: {node: '>=0.4.0'} @@ -7661,6 +7712,11 @@ packages: engines: {node: '>= 0.8'} dev: true + /cac@6.7.14: + resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==} + engines: {node: '>=8'} + dev: true + /call-bind@1.0.5: resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==} dependencies: @@ -9173,6 +9229,12 @@ packages: resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==} dev: true + /estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + dependencies: + '@types/estree': 1.0.5 + dev: true + /esutils@2.0.3: resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} engines: {node: '>=0.10.0'} @@ -10547,6 +10609,10 @@ packages: hasBin: true dev: true + /jsonc-parser@3.2.1: + resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==} + dev: true + /jsondiffpatch@0.6.0: resolution: {integrity: sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==} engines: {node: ^18.0.0 || >=20.0.0} @@ -10648,6 +10714,14 @@ packages: engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} dev: true + /local-pkg@0.5.0: + resolution: {integrity: sha512-ok6z3qlYyCDS4ZEU27HaU6x/xZa9Whf8jD4ptH5UZTQYZVYeb9bnZ3ojVhiJNLiXK1Hfc0GNbLXcmZ5plLDDBg==} + engines: {node: '>=14'} + dependencies: + mlly: 1.5.0 + pkg-types: 1.0.3 + dev: true + /locate-path@3.0.0: 
resolution: {integrity: sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==} engines: {node: '>=6'} @@ -10986,6 +11060,15 @@ packages: hasBin: true dev: true + /mlly@1.5.0: + resolution: {integrity: sha512-NPVQvAY1xr1QoVeG0cy8yUYC7FQcOx6evl/RjT1wL5FvzPnzOysoqB/jmx/DhssT2dYa8nxECLAaFI/+gVLhDQ==} + dependencies: + acorn: 8.11.3 + pathe: 1.1.2 + pkg-types: 1.0.3 + ufo: 1.3.2 + dev: true + /module-definition@3.4.0: resolution: {integrity: sha512-XxJ88R1v458pifaSkPNLUTdSPNVGMP2SXVncVmApGO+gAfrLANiYe6JofymCzVceGOMwQE2xogxBSc8uB7XegA==} engines: {node: '>=6.0'} @@ -11380,6 +11463,13 @@ packages: yocto-queue: 0.1.0 dev: true + /p-limit@5.0.0: + resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==} + engines: {node: '>=18'} + dependencies: + yocto-queue: 1.0.0 + dev: true + /p-locate@3.0.0: resolution: {integrity: sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==} engines: {node: '>=6'} @@ -11550,6 +11640,14 @@ packages: find-up: 5.0.0 dev: true + /pkg-types@1.0.3: + resolution: {integrity: sha512-nN7pYi0AQqJnoLPC9eHFQ8AcyaixBUOwvqc5TDnIKCMEE6I0y8P7OKA7fPexsXGCGxQDl/cmrLAp26LhcwxZ4A==} + dependencies: + jsonc-parser: 3.2.1 + mlly: 1.5.0 + pathe: 1.1.2 + dev: true + /pluralize@8.0.0: resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} engines: {node: '>=4'} @@ -12850,6 +12948,10 @@ packages: object-inspect: 1.13.1 dev: true + /siginfo@2.0.0: + resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==} + dev: true + /signal-exit@3.0.7: resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} dev: true @@ -12968,6 +13070,10 @@ packages: stackframe: 1.3.4 dev: false + /stackback@0.0.2: + resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==} + dev: true + /stackframe@1.3.4: resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==} dev: false @@ -12992,6 +13098,10 @@ packages: engines: {node: '>= 0.8'} dev: true + /std-env@3.7.0: + resolution: {integrity: sha512-JPbdCEQLj1w5GilpiHAx3qJvFndqybBysA3qUOnznweH4QbNYUsW/ea8QzSrnh0vNsezMMw5bcVool8lM0gwzg==} + dev: true + /stop-iteration-iterator@1.0.0: resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==} engines: {node: '>= 0.4'} @@ -13161,6 +13271,12 @@ packages: engines: {node: '>=8'} dev: true + /strip-literal@1.3.0: + resolution: {integrity: sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==} + dependencies: + acorn: 8.11.3 + dev: true + /stylis@4.2.0: resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==} dev: false @@ -13311,6 +13427,15 @@ packages: /tiny-invariant@1.3.1: resolution: {integrity: sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==} + /tinybench@2.6.0: + resolution: {integrity: sha512-N8hW3PG/3aOoZAN5V/NSAEDz0ZixDSSt5b/a05iqtpgfLWMSVuCo7w0k2vVvEjdrIoeGqZzweX2WlyioNIHchA==} + dev: true + + /tinypool@0.8.2: + resolution: {integrity: sha512-SUszKYe5wgsxnNOVlBYO6IC+8VGWdVGZWAqUxp3UErNBtptZvWbwyUOyzNL59zigz2rCA92QiL3wvG+JDSdJdQ==} + engines: {node: 
'>=14.0.0'} + dev: true + /tinyspy@2.2.0: resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==} engines: {node: '>=14.0.0'} @@ -13383,6 +13508,10 @@ packages: resolution: {integrity: sha512-nsZd8ZeNUzukXPlJmTBwUAuABDe/9qtVDelJeT/qW0ow3ZS3BsQJtNkan1802aM9Uf68/Y8ljw86Hu0h5IUW3w==} dev: true + /tsafe@1.6.6: + resolution: {integrity: sha512-gzkapsdbMNwBnTIjgO758GujLCj031IgHK/PKr2mrmkCSJMhSOR5FeOuSxKLMUoYc0vAA4RGEYYbjt/v6afD3g==} + dev: true + /tsconfck@3.0.1(typescript@5.3.3): resolution: {integrity: sha512-7ppiBlF3UEddCLeI1JRx5m2Ryq+xk4JrZuq4EuYXykipebaq1dV0Fhgr1hb7CkmHt32QSgOZlcqVLEtHBG4/mg==} engines: {node: ^18 || >=20} @@ -13828,6 +13957,27 @@ packages: engines: {node: '>= 0.8'} dev: true + /vite-node@1.2.2(@types/node@20.11.5): + resolution: {integrity: sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==} + engines: {node: ^18.0.0 || >=20.0.0} + hasBin: true + dependencies: + cac: 6.7.14 + debug: 4.3.4 + pathe: 1.1.2 + picocolors: 1.0.0 + vite: 5.0.12(@types/node@20.11.5) + transitivePeerDependencies: + - '@types/node' + - less + - lightningcss + - sass + - stylus + - sugarss + - supports-color + - terser + dev: true + /vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12): resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==} peerDependencies: @@ -13926,6 +14076,63 @@ packages: fsevents: 2.3.3 dev: true + /vitest@1.2.2(@types/node@20.11.5): + resolution: {integrity: sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==} + engines: {node: ^18.0.0 || >=20.0.0} + hasBin: true + peerDependencies: + '@edge-runtime/vm': '*' + '@types/node': ^18.0.0 || >=20.0.0 + '@vitest/browser': ^1.0.0 + '@vitest/ui': ^1.0.0 + happy-dom: '*' + jsdom: '*' + peerDependenciesMeta: + '@edge-runtime/vm': + optional: true + '@types/node': + optional: true + '@vitest/browser': + optional: true + '@vitest/ui': + optional: true + happy-dom: + optional: true + jsdom: + optional: true + dependencies: + '@types/node': 20.11.5 + '@vitest/expect': 1.2.2 + '@vitest/runner': 1.2.2 + '@vitest/snapshot': 1.2.2 + '@vitest/spy': 1.2.2 + '@vitest/utils': 1.2.2 + acorn-walk: 8.3.2 + cac: 6.7.14 + chai: 4.4.1 + debug: 4.3.4 + execa: 8.0.1 + local-pkg: 0.5.0 + magic-string: 0.30.5 + pathe: 1.1.2 + picocolors: 1.0.0 + std-env: 3.7.0 + strip-literal: 1.3.0 + tinybench: 2.6.0 + tinypool: 0.8.2 + vite: 5.0.12(@types/node@20.11.5) + vite-node: 1.2.2(@types/node@20.11.5) + why-is-node-running: 2.2.2 + transitivePeerDependencies: + - less + - lightningcss + - sass + - stylus + - sugarss + - supports-color + - terser + dev: true + /void-elements@3.1.0: resolution: {integrity: sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==} engines: {node: '>=0.10.0'} @@ -14049,6 +14256,15 @@ packages: isexe: 2.0.0 dev: true + /why-is-node-running@2.2.2: + resolution: {integrity: sha512-6tSwToZxTOcotxHeA+qGCq1mVzKR3CwcJGmVcY+QE8SHy6TnpFnh8PAvPNHYr7EcuVeG0QSMxtYCuO1ta/G/oA==} + engines: {node: '>=8'} + hasBin: true + dependencies: + siginfo: 2.0.0 + stackback: 0.0.2 + dev: true + /wordwrap@1.0.0: resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==} dev: true @@ -14189,6 +14405,11 @@ packages: engines: {node: '>=10'} dev: true + /yocto-queue@1.0.0: + resolution: {integrity: 
sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==} + engines: {node: '>=12.20'} + dev: true + /z-schema@5.0.5: resolution: {integrity: sha512-D7eujBWkLa3p2sIpJA0d1pr7es+a7m0vFAnZLlCEKq/Ij2k0MLi9Br2UPxoxdYystm5K1yeBGzub0FlYUEWj2Q==} engines: {node: '>=8.0.0'} diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index 2e2d2014b23..ed8c82d91ca 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,6 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { cloneDeep } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import type { Graph } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts index b2d36159098..88518e2c0bb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index d49f35cd2ab..75fa9e10949 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { isImageOutput } from 'features/nodes/types/common'; -import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants'; +import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; @@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => { const { data } = action.payload; log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); - const { result, node, queue_batch_id, source_node_id } = data; - + const { result, node, queue_batch_id } = data; // This complete event has an associated image output - if (isImageOutput(result) && 
!nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) {
+      if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
         const { image_name } = result.image;
         const { canvas, gallery } = getState();
@@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => {
           imageDTORequest.unsubscribe();
           // Add canvas images to the staging area
-          if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
+          if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
             dispatch(addImageToStagingArea(imageDTO));
           }
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts
index 752c3b09df2..ac1298da5ba 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts
@@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => {
     actionCreator: updateAllNodesRequested,
     effect: (action, { dispatch, getState }) => {
       const log = logger('nodes');
-      const nodes = getState().nodes.nodes;
-      const templates = getState().nodeTemplates.templates;
+      const { nodes, templates } = getState().nodes;
       let unableToUpdateCount = 0;
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts
index 46f55ef21ff..ab989301796 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts
@@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => {
         return;
       }
-      const { esrganModelName } = state.postprocessing;
-      const { autoAddBoardId } = state.gallery;
-
       const enqueueBatchArg: BatchConfig = {
         prepend: true,
         batch: {
           graph: buildAdHocUpscaleGraph({
             image_name,
-            esrganModelName,
-            autoAddBoardId,
+            state,
           }),
           runs: 1,
         },
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts
index 9307031e6d0..ad41dc2654f 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts
@@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => {
     effect: (action, { dispatch, getState }) => {
       const log = logger('nodes');
       const { workflow, asCopy } = action.payload;
-      const nodeTemplates = getState().nodeTemplates.templates;
+      const nodeTemplates = getState().nodes.templates;
       try {
         const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts
index e25e1351eb9..270662c3d21 100644
--- a/invokeai/frontend/web/src/app/store/store.ts
+++ b/invokeai/frontend/web/src/app/store/store.ts
@@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
 import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
 import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
 import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
-import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
 import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
 import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
 import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
@@ -46,7 +45,6 @@ const allReducers = {
   [gallerySlice.name]: gallerySlice.reducer,
   [generationSlice.name]: generationSlice.reducer,
   [nodesSlice.name]: nodesSlice.reducer,
-  [nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer,
   [postprocessingSlice.name]: postprocessingSlice.reducer,
   [systemSlice.name]: systemSlice.reducer,
   [configSlice.name]: configSlice.reducer,
diff --git a/invokeai/frontend/web/src/app/store/storeHooks.ts b/invokeai/frontend/web/src/app/store/storeHooks.ts
index f1a9aa979c0..6bc904acb31 100644
--- a/invokeai/frontend/web/src/app/store/storeHooks.ts
+++ b/invokeai/frontend/web/src/app/store/storeHooks.ts
@@ -1,7 +1,8 @@
 import type { AppThunkDispatch, RootState } from 'app/store/store';
 import type { TypedUseSelectorHook } from 'react-redux';
-import { useDispatch, useSelector } from 'react-redux';
+import { useDispatch, useSelector, useStore } from 'react-redux';
 // Use throughout your app instead of plain `useDispatch` and `useSelector`
 export const useAppDispatch = () => useDispatch();
 export const useAppSelector: TypedUseSelectorHook = useSelector;
+export const useAppStore = () => useStore();
diff --git a/invokeai/frontend/web/src/app/store/util.ts b/invokeai/frontend/web/src/app/store/util.ts
new file mode 100644
index 00000000000..381f7f85d26
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/util.ts
@@ -0,0 +1,2 @@
+export const EMPTY_ARRAY = [];
+export const EMPTY_OBJECT = {};
diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts
index 4952fa1c47b..baa704e75ca 100644
--- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts
+++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts
@@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
 import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
 import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
 import { isInvocationNode } from 'features/nodes/types/invocation';
 import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
 import { selectSystemSlice } from 'features/system/store/systemSlice';
@@ -23,11 +22,10 @@ const selector = createMemoizedSelector(
     selectGenerationSlice,
     selectSystemSlice,
     selectNodesSlice,
-    selectNodeTemplatesSlice,
     selectDynamicPromptsSlice,
     activeTabNameSelector,
   ],
-  (controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => {
+  (controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => {
    const { initialImage, model, positivePrompt } = generation;
    const { isConnected } = system;
@@ -54,7 +52,7 @@ const selector = createMemoizedSelector(
        return;
      }
-      const nodeTemplate = 
nodeTemplates.templates[node.data.type]; + const nodeTemplate = nodes.templates[node.data.type]; if (!nodeTemplate) { // Node type not found diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index b24b52c6abf..061209cafc0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { + addNodePopoverClosed, + addNodePopoverOpened, + nodeAdded, + selectNodesSlice, +} from 'features/nodes/store/nodesSlice'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; @@ -54,10 +58,10 @@ const AddNodePopover = () => { const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType); const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType); - const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => { + const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { // If we have a connection in progress, we need to filter the node choices const filteredNodeTemplates = fieldFilter - ? filter(nodeTemplates.templates, (template) => { + ? filter(nodes.templates, (template) => { const handles = handleFilter === 'source' ? 
template.inputs : template.outputs; return some(handles, (handle) => { @@ -67,7 +71,7 @@ const AddNodePopover = () => { return validateSourceAndTargetTypes(sourceType, targetType); }); }) - : map(nodeTemplates.templates); + : map(nodes.templates); const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 4bfc588e675..ba40b4984cd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,10 +1,17 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; +const defaultReturnValue = { + isSelected: false, + shouldAnimate: false, + stroke: colorTokenToCssVar('base.500'), +}; + export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, @@ -12,14 +19,19 @@ export const makeEdgeSelector = ( targetHandleId: string | null | undefined, selected?: boolean ) => - createMemoizedSelector(selectNodesSlice, (nodes) => { + createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = sourceNode?.selected || targetNode?.selected || selected; - const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined; + const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + if (!sourceNode || !sourceHandleId) { + return defaultReturnValue; + } + + const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); + const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; const stroke = sourceType && nodes.shouldColorEdges ? 
getFieldColor(sourceType) : colorTokenToCssVar('base.500'); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index c287842f6ed..b888e8a5162 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,6 +1,5 @@ import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/invocation'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import { map } from 'lodash-es'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; @@ -13,7 +12,7 @@ interface Props { const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' }; const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { - const data = useNodeData(nodeId); + const template = useNodeTemplate(nodeId); const { base600 } = useChakraThemeTokens(); const dummyHandleStyles: CSSProperties = useMemo( @@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { [dummyHandleStyles] ); - if (!isInvocationNodeData(data)) { + if (!template) { return null; } @@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { <> - {map(data.inputs, (input) => ( + {map(template.inputs, (input) => ( { ))} - {map(data.outputs, (output) => ( + {map(template.outputs, (output) => ( ) => { const { id: nodeId, type, isOpen, label } = data; const hasTemplateSelector = useMemo( - () => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])), + () => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx index c2231f703ab..e02b1a1474e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx @@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent'; interface Props { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; isMissingInput?: boolean; withTooltip?: boolean; } @@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => { return ( : undefined} + label={withTooltip ? 
: undefined} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} > { - const field = useFieldInstance(nodeId, fieldName); + const field = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); const isInputTemplate = isFieldInputTemplate(fieldTemplate); const fieldTypeName = useFieldTypeName(fieldTemplate?.type); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 2b9f7960e4b..66b0d3f7556 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { const [isHovered, setIsHovered] = useState(false); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'input' }); + useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { if (!fieldTemplate) { @@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { @@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index c1d52c1d4fb..b6e331c1149 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,6 +1,5 @@ -import { Box, Text } from '@invoke-ai/ui-library'; -import { useFieldInstance } from 'features/nodes/hooks/useFieldData'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; +import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate, @@ -38,7 +37,6 @@ import { isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; @@ -63,17 +61,8 @@ type InputFieldProps = { }; const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { - const { t } = useTranslation(); - const fieldInstance = useFieldInstance(nodeId, fieldName); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); - - if (fieldTemplate?.fieldKind === 'output') { - return ( - - {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} - - ); - } + const fieldInstance = useFieldInputInstance(nodeId, fieldName); + const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; @@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } - if (fieldInstance && fieldTemplate) { + if (fieldTemplate) { // Fallback for when there is no component for the type return null; } - - return ( - - - {t('nodes.unknownFieldType', { 
type: fieldInstance?.type.name })} - - - ); }; export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index d0a30ecc3c7..0cd199f7a47 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { /> - + {isValueChanged && ( { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 48c4c0d7404..f2d776a2da1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,6 +1,5 @@ import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance'; import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import type { PropsWithChildren } from 'react'; @@ -18,18 +17,17 @@ interface Props { const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const fieldInstance = useFieldOutputInstance(nodeId, fieldName); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'output' }); + useConnectionState({ nodeId, fieldName, kind: 'outputs' }); - if (!fieldTemplate || !fieldInstance) { + if (!fieldTemplate) { return ( {t('nodes.unknownOutput', { - name: fieldTemplate?.title ?? 
fieldName, + name: fieldName, })} @@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { return ( } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" shouldWrapChildren diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index b7c9033d6b2..d72d2f5aa8d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea'; import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { return; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index ee7dfaa6932..978eeddd24a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? 
nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index 28f0e82d68c..ea6e8ed704d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; return { template: lastSelectedNodeTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx index 0e5857933a7..e707dd4f54d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx @@ -16,7 +16,7 @@ type Props = { const WorkflowField = ({ nodeId, fieldName }: Props) => { const label = useFieldLabel(nodeId, fieldName); - const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'input'); + const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); return ( @@ -36,7 +36,7 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index d0263a8bdaf..c882924e241 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,26 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 
'features/nodes/types/invocation';
+import { selectNodeTemplate } from 'features/nodes/store/selectors';
 import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
 import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
 import { keys, map } from 'lodash-es';
 import { useMemo } from 'react';
-export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
+export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return [];
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
+        const template = selectNodeTemplate(nodes, nodeId);
+        if (!template) {
+          return EMPTY_ARRAY;
         }
-        const nodeTemplate = nodeTemplates.templates[node.data.type];
-        if (!nodeTemplate) {
-          return [];
-        }
-        const fields = map(nodeTemplate.inputs).filter(
+        const fields = map(template.inputs).filter(
           (field) =>
             (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
             keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts
index aecc9318938..b19edf3c85a 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts
@@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial = {
 };
 export const useBuildNode = () => {
-  const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates);
+  const nodeTemplates = useAppSelector((s) => s.nodes.templates);
   const flow = useReactFlow();
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts
index 23f318517b5..dc8a05b88c2 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts
@@ -1,28 +1,24 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
+import { EMPTY_ARRAY } from 'app/store/util';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeTemplate } from 'features/nodes/store/selectors';
 import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
 import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
 import { keys, map } from 'lodash-es';
 import { useMemo } from 'react';
-export const useConnectionInputFieldNames = (nodeId: string) => {
+export const useConnectionInputFieldNames = (nodeId: string): string[] => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return [];
-        }
-        const nodeTemplate = nodeTemplates.templates[node.data.type];
-        if (!nodeTemplate) {
-          return [];
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
+        const template = selectNodeTemplate(nodes, nodeId);
+        if (!template) {
+          return EMPTY_ARRAY;
         }
         // get the visible fields
-        const fields = map(nodeTemplate.inputs).filter(
+        const fields = map(template.inputs).filter(
           (field) =>
             (field.input === 'connection' && !field.type.isCollectionOrScalar) ||
             !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts
index a6f8b663f69..97b96f323ad 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts
@@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector(
 export type UseConnectionStateProps = {
   nodeId: string;
   fieldName: string;
-  kind: 'input' | 'output';
+  kind: 'inputs' | 'outputs';
 };
 export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
@@ -26,8 +26,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
         Boolean(
           nodes.edges.filter((edge) => {
             return (
-              (kind === 'input' ? edge.target : edge.source) === nodeId &&
-              (kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName
+              (kind === 'inputs' ? edge.target : edge.source) === nodeId &&
+              (kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName
             );
           }).length
         )
@@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
   );
   const selectConnectionError = useMemo(
-    () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType),
+    () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
     [nodeId, fieldName, kind, fieldType]
   );
@@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
         Boolean(
           nodes.connectionStartParams?.nodeId === nodeId &&
             nodes.connectionStartParams?.handleId === fieldName &&
-            nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind]
+            nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
         )
       ),
     [fieldName, kind, nodeId]
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts
index bfbf0a3b2d3..91994cf7525 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts
@@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { compareVersions } from 'compare-versions';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
-export const useDoNodeVersionsMatch = (nodeId: string) => {
+export const useDoNodeVersionsMatch = (nodeId: string): boolean => {
   const selector = useMemo(
     () =>
-      createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
+      createSelector(selectNodesSlice, (nodes) => {
+        const data = selectNodeData(nodes, nodeId);
+        const template = selectNodeTemplate(nodes, nodeId);
+        if (!template?.version || !data?.version) {
           return false;
         }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        if (!nodeTemplate?.version || !node.data?.version) {
-          return false;
-        }
-        return compareVersions(nodeTemplate.version, node.data.version) === 0;
+        return compareVersions(template.version, data.version) === 0;
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts
index cfe5c90d9cc..5051eaa55b3 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts
@@ -1,18 +1,18 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeData } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
-export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
+export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => {
   const selector = useMemo(
     () =>
       createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
+        const data = selectNodeData(nodes, nodeId);
+        if (!data) {
+          return false;
         }
-        return node?.data.inputs[fieldName]?.value !== undefined;
+        return data.inputs[fieldName]?.value !== undefined;
       }),
     [fieldName, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts
deleted file mode 100644
index 8b35a2d44be..00000000000
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts
+++ /dev/null
@@ -1,23 +0,0 @@
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
-import { useAppSelector } from 'app/store/storeHooks';
-import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
-import { useMemo } from 'react';
-
-export const useFieldInstance = (nodeId: string, fieldName: string) => {
-  const selector = useMemo(
-    () =>
-      createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        return node?.data.inputs[fieldName];
-      }),
-    [fieldName, nodeId]
-  );
-
-  const fieldData = useAppSelector(selector);
-
-  return fieldData;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts
index 0793f1f9529..25065e7aba5 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts
@@ -1,23 +1,20 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputInstance } from 'features/nodes/store/selectors';
+import type { FieldInputInstance } from 'features/nodes/types/field';
 import { useMemo } from 'react';
-export const useFieldInputInstance = (nodeId: string, fieldName: string) => {
+export const useFieldInputInstance = (nodeId: string, fieldName: string): FieldInputInstance | null => {
   const selector = useMemo(
     () =>
       createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        return node.data.inputs[fieldName];
+        return selectFieldInputInstance(nodes, nodeId, fieldName);
       }),
     [fieldName, nodeId]
   );
-  const fieldTemplate = useAppSelector(selector);
+  const fieldData = useAppSelector(selector);
-  return fieldTemplate;
+  return fieldData;
 };
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts
index 11d44dbde2e..08de3d9b205 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts
@@ -1,21 +1,16 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
+import type { FieldInput } from 'features/nodes/types/field';
 import { useMemo } from 'react';
 export const useFieldInputKind = (nodeId: string, fieldName: string) => {
   const selector = useMemo(
     () =>
-      createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        const fieldTemplate = nodeTemplate?.inputs[fieldName];
-        return fieldTemplate?.input;
+      createSelector(selectNodesSlice, (nodes): FieldInput | null => {
+        const template = selectFieldInputTemplate(nodes, nodeId, fieldName);
+        return template?.input ?? null;
       }),
     [fieldName, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts
index 8533d2be8df..e8289d7e07d 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts
@@ -1,20 +1,15 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
+import type { FieldInputTemplate } from 'features/nodes/types/field';
 import { useMemo } from 'react';
-export const useFieldInputTemplate = (nodeId: string, fieldName: string) => {
+export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        return nodeTemplate?.inputs[fieldName];
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
+        return selectFieldInputTemplate(nodes, nodeId, fieldName);
       }),
     [fieldName, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts
index ef57956047e..92eab8d1b15 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts
@@ -1,18 +1,14 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputInstance } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
-export const useFieldLabel = (nodeId: string, fieldName: string) => {
+export const useFieldLabel = (nodeId: string, fieldName: string): string | null => {
   const selector = useMemo(
     () =>
       createSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        return node?.data.inputs[fieldName]?.label;
+        return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null;
       }),
     [fieldName, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts
deleted file mode 100644
index 8b71f1ea014..00000000000
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts
+++ /dev/null
@@ -1,23 +0,0 @@
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
-import { useAppSelector } from 'app/store/storeHooks';
-import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
-import { useMemo } from 'react';
-
-export const useFieldOutputInstance = (nodeId: string, fieldName: string) => {
-  const selector = useMemo(
-    () =>
-      createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        return node.data.outputs[fieldName];
-      }),
-    [fieldName, nodeId]
-  );
-
-  const fieldTemplate = useAppSelector(selector);
-
-  return fieldTemplate;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts
index 11f592b399e..cb154071e97 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts
@@ -1,20 +1,15 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
+import type { FieldOutputTemplate } from 'features/nodes/types/field';
 import { useMemo } from 'react';
-export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => {
+export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
-        }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        return nodeTemplate?.outputs[fieldName];
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
+        return selectFieldOutputTemplate(nodes, nodeId, fieldName);
       }),
     [fieldName, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts
index 663821da81e..7be4ecfd4df 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts
@@ -1,21 +1,22 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { KIND_MAP } from 'features/nodes/types/constants';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
+import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
 import { useMemo } from 'react';
-export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
+export const useFieldTemplate = (
+  nodeId: string,
+  fieldName: string,
+  kind: 'inputs' | 'outputs'
+): FieldInputTemplate | FieldOutputTemplate | null => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
+        if (kind === 'inputs') {
+          return selectFieldInputTemplate(nodes, nodeId, fieldName);
         }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        return nodeTemplate?.[KIND_MAP[kind]][fieldName];
+        return selectFieldOutputTemplate(nodes, nodeId, fieldName);
       }),
     [fieldName, kind, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts
index cfdcda6efab..e41e0195724 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts
@@ -1,21 +1,17 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { KIND_MAP } from 'features/nodes/types/constants';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
-export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
+export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
   const selector = useMemo(
     () =>
-      createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
+      createSelector(selectNodesSlice, (nodes) => {
+        if (kind === 'inputs') {
+          return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
         }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title;
+        return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
       }),
     [fieldName, kind, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts
index a834726a136..a71a4d044ee 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts
@@ -1,20 +1,18 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { KIND_MAP } from 'features/nodes/types/constants';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
+import type { FieldType } from 'features/nodes/types/field';
 import { useMemo } from 'react';
-export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => {
+export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => {
   const selector = useMemo(
     () =>
       createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return;
+        if (kind === 'inputs') {
+          return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
         }
-        const field = node.data[KIND_MAP[kind]][fieldName];
-        return field?.type;
+        return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
       }),
     [fieldName, kind, nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts
index a8019c92d6d..71344197d54 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts
@@ -1,13 +1,12 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
 import { isInvocationNode } from 'features/nodes/types/invocation';
 import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
-const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) =>
+const selector = createSelector(selectNodesSlice, (nodes) =>
   nodes.nodes.filter(isInvocationNode).some((node) => {
-    const template = nodeTemplates.templates[node.data.type];
+    const template = nodes.templates[node.data.type];
     if (!template) {
       return false;
     }
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts
index 617e713c7cc..3ac3cabb220 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts
@@ -1,24 +1,21 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeTemplate } from 'features/nodes/store/selectors';
 import { some } from 'lodash-es';
 import { useMemo } from 'react';
-export const useHasImageOutput = (nodeId: string) => {
+export const useHasImageOutput = (nodeId: string): boolean => {
   const selector = useMemo(
     () =>
       createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return false;
-        }
+        const template = selectNodeTemplate(nodes, nodeId);
         return some(
-          node.data.outputs,
+          template?.outputs,
           (output) =>
             output.type.name === 'ImageField' &&
             // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
-            node.data.type !== 'image'
+            template?.type !== 'image'
         );
       }),
     [nodeId]
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts
index 729bfa0cea0..3fad0a2a861 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts
@@ -1,18 +1,14 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeData } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
-export const useIsIntermediate = (nodeId: string) => {
+export const useIsIntermediate = (nodeId: string): boolean => {
   const selector = useMemo(
     () =>
       createSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return false;
-        }
-        return node.data.isIntermediate;
+        return selectNodeData(nodes, nodeId)?.isIntermediate ?? false;
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
index 39a8abbe7a2..ded05c7b9bf 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
@@ -1,11 +1,10 @@
 // TODO: enable this at some point
-import { useAppSelector } from 'app/store/storeHooks';
+import { useAppSelector, useAppStore } from 'app/store/storeHooks';
 import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
 import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
 import type { InvocationNodeData } from 'features/nodes/types/invocation';
 import { useCallback } from 'react';
 import type { Connection, Node } from 'reactflow';
-import { useReactFlow } from 'reactflow';
 /**
 * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
 */
@@ -13,36 +12,31 @@ import { useReactFlow } from 'reactflow';
 */
 export const useIsValidConnection = () => {
-  const flow = useReactFlow();
+  const store = useAppStore();
   const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph);
   const isValidConnection = useCallback(
     ({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
-      const edges = flow.getEdges();
-      const nodes = flow.getNodes();
       // Connection must have valid targets
       if (!(source && sourceHandle && target && targetHandle)) {
         return false;
       }
-      // Find the source and target nodes
-      const sourceNode = flow.getNode(source) as Node;
-      const targetNode = flow.getNode(target) as Node;
-
-      // Conditional guards against undefined nodes/handles
-      if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
+      if (source === target) {
+        // Don't allow nodes to connect to themselves, even if validation is disabled
        return false;
      }
-      const sourceField = sourceNode.data.outputs[sourceHandle];
-      const targetField = targetNode.data.inputs[targetHandle];
+      const state = store.getState();
+      const { nodes, edges, templates } = state.nodes;
-      if (!sourceField || !targetField) {
-        // something has gone terribly awry
-        return false;
-      }
+      // Find the source and target nodes
+      const sourceNode = nodes.find((node) => node.id === source) as Node;
+      const targetNode = nodes.find((node) => node.id === target) as Node;
+      const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
+      const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
-      if (source === target) {
-        // Don't allow nodes to connect to themselves, even if validation is disabled
+      // Conditional guards against undefined nodes/handles
+      if (!(sourceFieldTemplate && targetFieldTemplate)) {
        return false;
      }
@@ -69,20 +63,20 @@ export const useIsValidConnection = () => {
          return edge.target === target && edge.targetHandle === targetHandle;
        }) &&
        // except CollectionItem inputs can have multiples
-        targetField.type.name !== 'CollectionItemField'
+        targetFieldTemplate.type.name !== 'CollectionItemField'
      ) {
        return false;
      }
      // Must use the originalType here if it exists
-      if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) {
+      if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
        return false;
      }
      // Graphs much be acyclic (no loops!)
      return getIsGraphAcyclic(source, target, nodes, edges);
    },
-    [flow, shouldValidateGraph]
+    [shouldValidateGraph, store]
  );
  return isValidConnection;
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts
index c61721030eb..bab8ff3f194 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts
@@ -1,20 +1,15 @@
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeTemplate } from 'features/nodes/store/selectors';
+import type { Classification } from 'features/nodes/types/common';
 import { useMemo } from 'react';
-export const useNodeClassification = (nodeId: string) => {
+export const useNodeClassification = (nodeId: string): Classification | null => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return false;
-        }
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        return nodeTemplate?.classification;
+      createSelector(selectNodesSlice, (nodes) => {
+        return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts
index c507def5ee3..fa21008ff8b 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts
@@ -1,14 +1,15 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
+import { selectNodeData } from 'features/nodes/store/selectors';
+import type { InvocationNodeData } from 'features/nodes/types/invocation';
 import { useMemo } from 'react';
-export const useNodeData = (nodeId: string) => {
+export const useNodeData = (nodeId: string): InvocationNodeData | null => {
   const selector = useMemo(
     () =>
       createMemoizedSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        return node?.data;
+        return selectNodeData(nodes, nodeId);
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts
index c5fc43742a1..31dcb9c466e 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts
@@ -1,19 +1,14 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeData } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
 export const useNodeLabel = (nodeId: string) => {
   const selector = useMemo(
     () =>
       createSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return false;
-        }
-
-        return node.data.label;
+        return selectNodeData(nodes, nodeId)?.label ?? null;
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts
index e6efa667f12..aa0294f70f0 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts
@@ -1,21 +1,20 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
 import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
 import { useMemo } from 'react';
 export const useNodeNeedsUpdate = (nodeId: string) => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        const template = nodeTemplates.templates[node?.data.type ?? ''];
-        if (isInvocationNode(node) && template) {
-          return getNeedsUpdate(node, template);
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
+        const node = selectInvocationNode(nodes, nodeId);
+        const template = selectNodeTemplate(nodes, nodeId);
+        if (!node || !template) {
+          return false;
         }
-        return false;
+        return getNeedsUpdate(node, template);
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts
index ca3dd5cfdf6..5c920866e9d 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts
@@ -1,18 +1,14 @@
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { isInvocationNode } from 'features/nodes/types/invocation';
+import { selectNodeData } from 'features/nodes/store/selectors';
 import { useMemo } from 'react';
-export const useNodePack = (nodeId: string) => {
+export const useNodePack = (nodeId: string): string | null => {
   const selector = useMemo(
     () =>
       createSelector(selectNodesSlice, (nodes) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        if (!isInvocationNode(node)) {
-          return false;
-        }
-        return node.data.nodePack;
+        return selectNodeData(nodes, nodeId)?.nodePack ?? null;
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts
index 7544cbff461..866c9275fb3 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts
@@ -1,16 +1,15 @@
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
+import { selectNodeTemplate } from 'features/nodes/store/selectors';
+import type { InvocationTemplate } from 'features/nodes/types/invocation';
 import { useMemo } from 'react';
-export const useNodeTemplate = (nodeId: string) => {
+export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
-        const node = nodes.nodes.find((node) => node.id === nodeId);
-        const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
-        return nodeTemplate;
+      createSelector(selectNodesSlice, (nodes) => {
+        return selectNodeTemplate(nodes, nodeId);
       }),
     [nodeId]
   );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts
index 8fd1345f6f5..a0c870f6941 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts
@@ -1,14 +1,14 @@
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
-import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
+import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
 import type { InvocationTemplate } from 'features/nodes/types/invocation';
 import { useMemo } from 'react';
-export const useNodeTemplateByType = (type: string) => {
+export const useNodeTemplateByType = (type: string): InvocationTemplate | null => {
   const selector = useMemo(
     () =>
-      createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => {
-        return nodeTemplates.templates[type];
+      createSelector(selectNodesSlice, (nodes) => {
+        return nodes.templates[type] ?? 
null; }), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 15d2ec38c32..120b8c758be 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,21 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodeTemplateTitle = (nodeId: string) => { +export const useNodeTemplateTitle = (nodeId: string): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined; - - return nodeTemplate?.title; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.title ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e352bd8b90f..24863080a74 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -1,8 +1,8 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { map } from 'lodash-es'; import { useMemo } from 'react'; @@ -10,17 +10,13 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - return getSortedFilteredFieldNames(map(nodeTemplate.outputs)); + return getSortedFilteredFieldNames(map(template.outputs)); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index edfc990882b..aaca80039b0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -1,18 +1,14 @@ import { 
createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useUseCache = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.useCache; + return selectNodeData(nodes, nodeId)?.useCache ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts index 0e4806d81b7..5d79c154428 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts @@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow'; import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import { useEffect } from 'react'; -export const $builtWorkflow = atom<WorkflowV2 | null>(null); +export const $builtWorkflow = atom<WorkflowV3 | null>(null); const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => { $builtWorkflow.set(buildWorkflowFast(arg)); diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 00457494bfb..b32a3ba9979 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,5 +1,5 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { Graph } from 'services/api/types'; export const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt'); @@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{ export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested'); -export const workflowLoaded = createAction<WorkflowV2>('workflow/workflowLoaded'); +export const workflowLoaded = createAction<WorkflowV3>('workflow/workflowLoaded'); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts deleted file mode 100644 index c211131aab7..00000000000 --- a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import type { RootState } from 'app/store/store'; -import type { InvocationTemplate } from 'features/nodes/types/invocation'; - -import type { NodeTemplatesState } from './types'; - -export const initialNodeTemplatesState: NodeTemplatesState = { - templates: {}, -}; - -export const nodesTemplatesSlice =
createSlice({ - name: 'nodeTemplates', - initialState: initialNodeTemplatesState, - reducers: { - nodeTemplatesBuilt: (state, action: PayloadAction>) => { - state.templates = action.payload; - }, - }, -}); - -export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions; - -export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index aee01b381ba..6b596da0633 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -42,7 +42,7 @@ import { zT2IAdapterModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; -import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation'; +import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; import { cloneDeep, forEach } from 'lodash-es'; import type { @@ -92,6 +92,7 @@ export const initialNodesState: NodesState = { _version: 1, nodes: [], edges: [], + templates: {}, connectionStartParams: null, connectionStartFieldType: null, connectionMade: false, @@ -190,6 +191,7 @@ export const nodesSlice = createSlice({ node, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -224,12 +226,12 @@ export const nodesSlice = createSlice({ if (!nodeId || !handleId) { return; } - const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - const node = state.nodes?.[nodeIndex]; + const node = state.nodes.find((n) => n.id === nodeId); if (!isInvocationNode(node)) { return; } - const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; + const template = state.templates[node.data.type]; + const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId]; state.connectionStartFieldType = field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { @@ -260,6 +262,7 @@ export const nodesSlice = createSlice({ mouseOverNode, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -677,6 +680,9 @@ export const nodesSlice = createSlice({ selectionModeChanged: (state, action: PayloadAction) => { state.selectionMode = action.payload ? 
SelectionMode.Full : SelectionMode.Partial; }, + nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => { + state.templates = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(workflowLoaded, (state, action) => { @@ -808,6 +814,7 @@ export const { shouldValidateGraphChanged, viewportChanged, edgeAdded, + nodeTemplatesBuilt, } = nodesSlice.actions; // This is used for tracking `state.workflow.isTouched` diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts new file mode 100644 index 00000000000..90675d62707 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -0,0 +1,51 @@ +import type { NodesState } from 'features/nodes/store/types'; +import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; + +export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return null; + } + return node; +}; + +export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => { + return selectInvocationNode(nodesSlice, nodeId)?.data ?? null; +}; + +export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => { + const node = selectInvocationNode(nodesSlice, nodeId); + if (!node) { + return null; + } + return nodesSlice.templates[node.data.type] ?? null; +}; + +export const selectFieldInputInstance = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputInstance | null => { + const data = selectNodeData(nodesSlice, nodeId); + return data?.inputs[fieldName] ?? null; +}; + +export const selectFieldInputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.inputs[fieldName] ?? null; +}; + +export const selectFieldOutputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldOutputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.outputs[fieldName] ??
null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 8b0de447e43..1a040d2c705 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -5,13 +5,14 @@ import type { InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow'; export type NodesState = { _version: 1; nodes: AnyNode[]; edges: InvocationNodeEdge[]; + templates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; connectionMade: boolean; @@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & { value: StatefulFieldValue; }; -export type WorkflowsState = Omit & { +export type WorkflowsState = Omit & { _version: 1; isTouched: boolean; mode: WorkflowMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 9f2c37a2ad7..ef899c5f414 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,4 +1,6 @@ -import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field'; +import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import type { Connection, Edge, HandleType, Node } from 'reactflow'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; @@ -9,7 +11,7 @@ const isValidConnection = ( handleCurrentType: HandleType, handleCurrentFieldType: FieldType, node: Node, - handle: FieldInputInstance | FieldOutputInstance + handle: FieldInputTemplate | FieldOutputTemplate ) => { let isValidConnection = true; if (handleCurrentType === 'source') { @@ -38,24 +40,31 @@ const isValidConnection = ( }; export const findConnectionToValidHandle = ( - node: Node, - nodes: Node[], - edges: Edge[], + node: AnyNode, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + templates: Record, handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, handleCurrentFieldType: FieldType ): Connection | null => { - if (node.id === handleCurrentNodeId) { + if (node.id === handleCurrentNodeId || !isInvocationNode(node)) { return null; } - const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs; + const template = templates[node.data.type]; + + if (!template) { + return null; + } + + const handles = handleCurrentType === 'source' ? template.inputs : template.outputs; //Prioritize handles whos name matches the node we're coming from - if (handles[handleCurrentName]) { - const handle = handles[handleCurrentName]; + const handle = handles[handleCurrentName]; + if (handle) { const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; const sourceHandle = handleCurrentType === 'source' ? 
handleCurrentName : handle.name; @@ -77,6 +86,9 @@ export const findConnectionToValidHandle = ( for (const handleName in handles) { const handle = handles[handleName]; + if (!handle) { + continue; + } const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 8575932cbdd..d6ea0d9c86e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | null ) => { return createSelector(selectNodesSlice, (nodesSlice) => { if (!fieldType) { diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 2978f25138d..4f40a68e1f0 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -10,10 +10,10 @@ import type { } from 'features/nodes/store/types'; import type { FieldIdentifier } from 'features/nodes/types/field'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow'; import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es'; -export const blankWorkflow: Omit = { +export const blankWorkflow: Omit = { name: '', author: '', description: '', @@ -22,7 +22,7 @@ export const blankWorkflow: Omit = { tags: '', notes: '', exposedFields: [], - meta: { version: '2.0.0', category: 'user' }, + meta: { version: '3.0.0', category: 'user' }, id: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts new file mode 100644 index 00000000000..7f28e864a13 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts @@ -0,0 +1,69 @@ +import type { + BaseModel, + BoardField, + Classification, + CLIPField, + ColorField, + ControlField, + ControlNetModelField, + ImageField, + ImageOutput, + IPAdapterField, + IPAdapterModelField, + LoraInfo, + LoRAModelField, + MainModelField, + ModelInfo, + ModelType, + ProgressImage, + SchedulerField, + SDXLRefinerModelField, + SubModelType, + T2IAdapterField, + T2IAdapterModelField, + UNetField, + VAEField, +} from 'features/nodes/types/common'; +import type { S } from 'services/api/types'; +import type { Equals, Extends } from 'tsafe'; +import { assert } from 'tsafe'; +import { describe, test } from 'vitest'; + +/** + * These types originate from the server and are recreated as zod schemas manually, for use at runtime. + * The tests ensure that the types are correctly recreated. 
+ */ + +describe('Common types', () => { + // Complex field types + test('ImageField', () => assert>()); + test('BoardField', () => assert>()); + test('ColorField', () => assert>()); + test('SchedulerField', () => assert>>()); + test('UNetField', () => assert>()); + test('CLIPField', () => assert>()); + test('MainModelField', () => assert>()); + test('SDXLRefinerModelField', () => assert>()); + test('VAEField', () => assert>()); + test('ControlField', () => assert>()); + // @ts-expect-error TODO(psyche): fix types + test('IPAdapterField', () => assert>()); + test('T2IAdapterField', () => assert>()); + test('LoRAModelField', () => assert>()); + test('ControlNetModelField', () => assert>()); + test('IPAdapterModelField', () => assert>()); + test('T2IAdapterModelField', () => assert>()); + + // Model component types + test('BaseModel', () => assert>()); + test('ModelType', () => assert>()); + test('SubModelType', () => assert>()); + test('ModelInfo', () => assert>()); + + // Misc types + test('LoraInfo', () => assert>()); + // @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet + test('ProgressImage', () => assert>()); + test('ImageOutput', () => assert>()); + test('Classification', () => assert>()); +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index b5244743799..ef579fce8cf 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -52,27 +52,29 @@ export type SchedulerField = z.infer; // #region Model-related schemas export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); -export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelType = z.enum([ + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', + 'ip_adapter', + 'clip_vision', + 't2i_adapter', + 'onnx', // TODO(psyche): Remove this when removed from backend +]); export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ - model_name: zModelName, - base_model: zBaseModel, + key: z.string().min(1), }); export type BaseModel = z.infer; export type ModelType = z.infer; export type ModelIdentifier = z.infer; -export const zMainModelField = z.object({ - model_name: zModelName, - base_model: zBaseModel, - model_type: z.literal('main'), -}); -export const zSDXLRefinerModelField = z.object({ - model_name: z.string().min(1), - base_model: z.literal('sdxl-refiner'), - model_type: z.literal('main'), -}); +export const zMainModelField = zModelIdentifier; export type MainModelField = z.infer; + +export const zSDXLRefinerModelField = zModelIdentifier; export type SDXLRefinerModelField = z.infer; export const zSubModelType = z.enum([ @@ -92,8 +94,7 @@ export type SubModelType = z.infer; export const zVAEModelField = zModelIdentifier; export const zModelInfo = zModelIdentifier.extend({ - model_type: zModelType, - submodel: zSubModelType.optional(), + submodel_type: zSubModelType.nullish(), }); export type ModelInfo = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts index 905b487fb04..82bc0f86e09 100644 --- a/invokeai/frontend/web/src/features/nodes/types/error.ts +++ b/invokeai/frontend/web/src/features/nodes/types/error.ts @@ -56,3 +56,8 @@ export class FieldParseError extends Error { this.name = this.constructor.name; } } + +export class 
UnableToExtractSchemaNameFromRefError extends FieldParseError {} +export class UnsupportedArrayItemType extends FieldParseError {} +export class UnsupportedUnionError extends FieldParseError {} +export class UnsupportedPrimitiveTypeError extends FieldParseError {} diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 38f1af55dd8..aa6164d6e53 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -46,20 +46,11 @@ export type FieldInput = z.infer; export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); export type FieldUIComponent = z.infer; -export const zFieldInstanceBase = z.object({ - id: z.string().trim().min(1), +export const zFieldInputInstanceBase = z.object({ name: z.string().trim().min(1), -}); -export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('input'), label: z.string().nullish(), }); -export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('output'), -}); -export type FieldInstanceBase = z.infer; export type FieldInputInstanceBase = z.infer; -export type FieldOutputInstanceBase = z.infer; export const zFieldTemplateBase = z.object({ name: z.string().min(1), @@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({ }); export const zIntegerFieldValue = z.number().int(); export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIntegerFieldType, value: zIntegerFieldValue, }); -export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIntegerFieldType, -}); export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, default: zIntegerFieldValue, @@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({ }); export const zFloatFieldValue = z.number(); export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zFloatFieldType, value: zFloatFieldValue, }); -export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zFloatFieldType, -}); export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, default: zFloatFieldValue, @@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type FloatFieldType = z.infer; export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; -export type FloatFieldOutputInstance = z.infer; export type FloatFieldInputTemplate = z.infer; export type FloatFieldOutputTemplate = z.infer; export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => @@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({ }); export const zStringFieldValue = z.string(); export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStringFieldType, value: zStringFieldValue, }); -export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStringFieldType, -}); export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, default: zStringFieldValue, @@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StringFieldType = z.infer; export type StringFieldValue = z.infer; export type StringFieldInputInstance = z.infer; -export type StringFieldOutputInstance = z.infer; export type 
StringFieldInputTemplate = z.infer; export type StringFieldOutputTemplate = z.infer; export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => @@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({ }); export const zBooleanFieldValue = z.boolean(); export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBooleanFieldType, value: zBooleanFieldValue, }); -export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBooleanFieldType, -}); export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, default: zBooleanFieldValue, @@ -222,7 +195,6 @@ export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BooleanFieldType = z.infer; export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; -export type BooleanFieldOutputInstance = z.infer; export type BooleanFieldInputTemplate = z.infer; export type BooleanFieldOutputTemplate = z.infer; export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => @@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({ }); export const zEnumFieldValue = z.string(); export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zEnumFieldType, value: zEnumFieldValue, }); -export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zEnumFieldType, -}); export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, default: zEnumFieldValue, @@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type EnumFieldType = z.infer; export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; -export type EnumFieldOutputInstance = z.infer; export type EnumFieldInputTemplate = z.infer; export type EnumFieldOutputTemplate = z.infer; export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => @@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({ }); export const zImageFieldValue = zImageField.optional(); export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zImageFieldType, value: zImageFieldValue, }); -export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zImageFieldType, -}); export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, default: zImageFieldValue, @@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ImageFieldType = z.infer; export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; -export type ImageFieldOutputInstance = z.infer; export type ImageFieldInputTemplate = z.infer; export type ImageFieldOutputTemplate = z.infer; export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => @@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({ }); export const zBoardFieldValue = zBoardField.optional(); export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBoardFieldType, value: zBoardFieldValue, }); -export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBoardFieldType, -}); export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, default: zBoardFieldValue, @@ -317,7 +275,6 @@ export const 
zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BoardFieldType = z.infer; export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; -export type BoardFieldOutputInstance = z.infer; export type BoardFieldInputTemplate = z.infer; export type BoardFieldOutputTemplate = z.infer; export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => @@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({ }); export const zColorFieldValue = zColorField.optional(); export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zColorFieldType, value: zColorFieldValue, }); -export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zColorFieldType, -}); export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, default: zColorFieldValue, @@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ColorFieldType = z.infer; export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; -export type ColorFieldOutputInstance = z.infer; export type ColorFieldInputTemplate = z.infer; export type ColorFieldOutputTemplate = z.infer; export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => @@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({ }); export const zMainModelFieldValue = zMainModelField.optional(); export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zMainModelFieldType, value: zMainModelFieldValue, }); -export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zMainModelFieldType, -}); export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, default: zMainModelFieldValue, @@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type MainModelFieldType = z.infer; export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; -export type MainModelFieldOutputInstance = z.infer; export type MainModelFieldInputTemplate = z.infer; export type MainModelFieldOutputTemplate = z.infer; export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => @@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. 
export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLMainModelFieldType, value: zSDXLMainModelFieldValue, }); -export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLMainModelFieldType, -}); export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, default: zSDXLMainModelFieldValue, @@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend export type SDXLMainModelFieldType = z.infer; export type SDXLMainModelFieldValue = z.infer; export type SDXLMainModelFieldInputInstance = z.infer; -export type SDXLMainModelFieldOutputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; export type SDXLMainModelFieldOutputTemplate = z.infer; export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => @@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, value: zSDXLRefinerModelFieldValue, }); -export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, -}); export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, default: zSDXLRefinerModelFieldValue, @@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext export type SDXLRefinerModelFieldType = z.infer; export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; -export type SDXLRefinerModelFieldOutputInstance = z.infer; export type SDXLRefinerModelFieldInputTemplate = z.infer; export type SDXLRefinerModelFieldOutputTemplate = z.infer; export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => @@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({ }); export const zVAEModelFieldValue = zVAEModelField.optional(); export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zVAEModelFieldType, value: zVAEModelFieldValue, }); -export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zVAEModelFieldType, -}); export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, default: zVAEModelFieldValue, @@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type VAEModelFieldType = z.infer; export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; -export type VAEModelFieldOutputInstance = z.infer; export type VAEModelFieldInputTemplate = z.infer; export type VAEModelFieldOutputTemplate = z.infer; export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => @@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({ }); export const zLoRAModelFieldValue = zLoRAModelField.optional(); export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zLoRAModelFieldType, value: zLoRAModelFieldValue, }); -export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zLoRAModelFieldType, -}); export 
const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, default: zLoRAModelFieldValue, @@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type LoRAModelFieldType = z.infer; export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; -export type LoRAModelFieldOutputInstance = z.infer; export type LoRAModelFieldInputTemplate = z.infer; export type LoRAModelFieldOutputTemplate = z.infer; export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => @@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({ }); export const zControlNetModelFieldValue = zControlNetModelField.optional(); export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zControlNetModelFieldType, value: zControlNetModelFieldValue, }); -export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zControlNetModelFieldType, -}); export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, default: zControlNetModelFieldValue, @@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type ControlNetModelFieldType = z.infer; export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; -export type ControlNetModelFieldOutputInstance = z.infer; export type ControlNetModelFieldInputTemplate = z.infer; export type ControlNetModelFieldOutputTemplate = z.infer; export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => @@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIPAdapterModelFieldType, value: zIPAdapterModelFieldValue, }); -export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIPAdapterModelFieldType, -}); export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue, @@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten export type IPAdapterModelFieldType = z.infer; export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; -export type IPAdapterModelFieldOutputInstance = z.infer; export type IPAdapterModelFieldInputTemplate = z.infer; export type IPAdapterModelFieldOutputTemplate = z.infer; export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => @@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, value: zT2IAdapterModelFieldValue, }); -export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, -}); export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, default: zT2IAdapterModelFieldValue, @@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate 
= zFieldOutputTemplateBase.exte export type T2IAdapterModelFieldType = z.infer; export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; -export type T2IAdapterModelFieldOutputInstance = z.infer; export type T2IAdapterModelFieldInputTemplate = z.infer; export type T2IAdapterModelFieldOutputTemplate = z.infer; export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => @@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({ }); export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSchedulerFieldType, value: zSchedulerFieldValue, }); -export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSchedulerFieldType, -}); export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, default: zSchedulerFieldValue, @@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type SchedulerFieldType = z.infer; export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; -export type SchedulerFieldOutputInstance = z.infer; export type SchedulerFieldInputTemplate = z.infer; export type SchedulerFieldOutputTemplate = z.infer; export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => @@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({ }); export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStatelessFieldType, value: zStatelessFieldValue, }); -export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStatelessFieldType, -}); export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStatelessFieldType, default: zStatelessFieldValue, @@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StatelessFieldType = z.infer; export type StatelessFieldValue = z.infer; export type StatelessFieldInputInstance = z.infer; -export type StatelessFieldOutputInstance = z.infer; export type StatelessFieldInputTemplate = z.infer; export type StatelessFieldOutputTemplate = z.infer; // #endregion @@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => zFieldInputInstance.safeParse(val).success; // #endregion -// #region StatefulFieldOutputInstance & FieldOutputInstance -export const zStatefulFieldOutputInstance = z.union([ - zIntegerFieldOutputInstance, - zFloatFieldOutputInstance, - zStringFieldOutputInstance, - zBooleanFieldOutputInstance, - zEnumFieldOutputInstance, - zImageFieldOutputInstance, - zBoardFieldOutputInstance, - zMainModelFieldOutputInstance, - zSDXLMainModelFieldOutputInstance, - zSDXLRefinerModelFieldOutputInstance, - zVAEModelFieldOutputInstance, - zLoRAModelFieldOutputInstance, - zControlNetModelFieldOutputInstance, - zIPAdapterModelFieldOutputInstance, - zT2IAdapterModelFieldOutputInstance, - zColorFieldOutputInstance, - zSchedulerFieldOutputInstance, -]); -export type StatefulFieldOutputInstance = z.infer; -export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => - 
zStatefulFieldOutputInstance.safeParse(val).success; - -export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); -export type FieldOutputInstance = z.infer; -export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => - zFieldOutputInstance.safeParse(val).success; -// #endregion - // #region StatefulFieldInputTemplate & FieldInputTemplate export const zStatefulFieldInputTemplate = z.union([ zIntegerFieldInputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 86ec70fd9bd..5ccb19430da 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow'; import { z } from 'zod'; import { zClassification, zProgressImage } from './common'; -import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field'; import { zSemVer } from './semver'; // #region InvocationTemplate @@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer; // #region NodeData export const zInvocationNodeData = z.object({ id: z.string().trim().min(1), - type: z.string().trim().min(1), + version: zSemVer, + nodePack: z.string().min(1).nullish(), label: z.string(), - isOpen: z.boolean(), notes: z.string(), + type: z.string().trim().min(1), + inputs: z.record(zFieldInputInstance), + isOpen: z.boolean(), isIntermediate: z.boolean(), useCache: z.boolean(), - version: zSemVer, - nodePack: z.string().min(1).nullish(), - inputs: z.record(zFieldInputInstance), - outputs: z.record(zFieldOutputInstance), }); export const zNotesNodeData = z.object({ @@ -62,11 +61,12 @@ export type NotesNode = Node; export type CurrentImageNode = Node; export type AnyNode = Node; -export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => +export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); -export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => +export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData => Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts new file mode 100644 index 00000000000..b5244743799 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -0,0 +1,188 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: 
z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zClassification = z.enum(['stable', 'beta', 'prototype']); +export type Classification = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelName = z.string().min(3); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, +}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 
'just_resize_simple']).optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int().gt(0), + height: z.number().int().gt(0), + type: z.literal('image_output'), +}); +export type ImageOutput = z.infer; +export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts new file mode 100644 index 00000000000..35ef9e9fd2c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts @@ -0,0 +1,80 @@ +import type { Node } from 'reactflow'; + +/** + * How long to wait before showing a tooltip when hovering a field handle. + */ +export const HANDLE_TOOLTIP_OPEN_DELAY = 500; + +/** + * The width of a node in the UI in pixels. + */ +export const NODE_WIDTH = 320; + +/** + * This class name is special - reactflow uses it to identify the drag handle of a node, + * applying the appropriate listeners to it. + */ +export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; + +/** + * reactflow-specifc properties shared between all node types. + */ +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + +/** + * Helper for getting the kind of a field. + */ +export const KIND_MAP = { + input: 'inputs' as const, + output: 'outputs' as const, +}; + +/** + * Model types' handles are rendered as squares in the UI. + */ +export const MODEL_TYPES = [ + 'IPAdapterModelField', + 'ControlNetModelField', + 'LoRAModelField', + 'MainModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'UNetField', + 'VaeField', + 'ClipField', + 'T2IAdapterModelField', + 'IPAdapterModelField', +]; + +/** + * Colors for each field type - applies to their handles and edges. 
+ */ +export const FIELD_COLORS: { [key: string]: string } = { + BoardField: 'purple.500', + BooleanField: 'green.500', + ClipField: 'green.500', + ColorField: 'pink.300', + ConditioningField: 'cyan.500', + ControlField: 'teal.500', + ControlNetModelField: 'teal.500', + EnumField: 'blue.500', + FloatField: 'orange.500', + ImageField: 'purple.500', + IntegerField: 'red.500', + IPAdapterField: 'teal.500', + IPAdapterModelField: 'teal.500', + LatentsField: 'pink.500', + LoRAModelField: 'teal.500', + MainModelField: 'teal.500', + SDXLMainModelField: 'teal.500', + SDXLRefinerModelField: 'teal.500', + StringField: 'yellow.500', + T2IAdapterField: 'teal.500', + T2IAdapterModelField: 'teal.500', + UNetField: 'red.500', + VaeField: 'blue.500', + VaeModelField: 'teal.500', +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/error.ts b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts new file mode 100644 index 00000000000..905b487fb04 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts @@ -0,0 +1,58 @@ +/** + * Invalid Workflow Version Error + * Raised when a workflow version is not recognized. + */ +export class WorkflowVersionError extends Error { + /** + * Create WorkflowVersionError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} +/** + * Workflow Migration Error + * Raised when a workflow migration fails. + */ +export class WorkflowMigrationError extends Error { + /** + * Create WorkflowMigrationError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Unable to Update Node Error + * Raised when a node cannot be updated. + */ +export class NodeUpdateError extends Error { + /** + * Create NodeUpdateError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * FieldParseError + * Raised when a field cannot be parsed from a field schema. + */ +export class FieldParseError extends Error { + /** + * Create FieldTypeParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts new file mode 100644 index 00000000000..38f1af55dd8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -0,0 +1,875 @@ +import { z } from 'zod'; + +import { + zBoardField, + zColorField, + zControlNetModelField, + zImageField, + zIPAdapterModelField, + zLoRAModelField, + zMainModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from './common'; + +/** + * zod schemas & inferred types for fields. + * + * These schemas and types are only required for stateful field - fields that have UI components + * and allow the user to directly provide values. + * + * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. + * + * If a field type does not have a UI component, then it does not need to be included here, because + * we never store its value. Such field types will be handled via the "StatelessField" logic. 
+ * + * Fields require: + * - zFieldType - zod schema for the field type + * - zFieldValue - zod schema for the field value + * - zFieldInputInstance - zod schema for the field's input instance + * - zFieldOutputInstance - zod schema for the field's output instance + * - zFieldInputTemplate - zod schema for the field's input template + * - zFieldOutputTemplate - zod schema for the field's output template + * - inferred types for each schema + * - type guards for InputInstance and InputTemplate + * + * These then must be added to the unions at the bottom of this file. + */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isCollectionOrScalar: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer; +export type 
IntegerFieldInputTemplate = z.infer; +export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer; +export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer; +export type StringFieldOutputInstance = z.infer; +export type StringFieldInputTemplate = z.infer; +export type StringFieldOutputTemplate = z.infer; +export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = 
zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer; +export type BooleanFieldOutputInstance = z.infer; +export type BooleanFieldInputTemplate = z.infer; +export type BooleanFieldOutputTemplate = z.infer; +export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => + zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer; +export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: 
zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer; +export type BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer; +export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer; +export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer; +export type MainModelFieldOutputInstance = z.infer; +export type MainModelFieldInputTemplate = z.infer; +export type MainModelFieldOutputTemplate = z.infer; +export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region 
SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, +}); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, +}); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer; +export type SDXLMainModelFieldOutputInstance = z.infer; +export type SDXLMainModelFieldInputTemplate = z.infer; +export type SDXLMainModelFieldOutputTemplate = z.infer; +export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export type SDXLRefinerModelFieldType = z.infer; +export type SDXLRefinerModelFieldValue = z.infer; +export type SDXLRefinerModelFieldInputInstance = z.infer; +export type SDXLRefinerModelFieldOutputInstance = z.infer; +export type SDXLRefinerModelFieldInputTemplate = z.infer; +export type SDXLRefinerModelFieldOutputTemplate = z.infer; +export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: 
zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = z.infer; +export type VAEModelFieldInputInstance = z.infer; +export type VAEModelFieldOutputInstance = z.infer; +export type VAEModelFieldInputTemplate = z.infer; +export type VAEModelFieldOutputTemplate = z.infer; +export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer; +export type LoRAModelFieldOutputInstance = z.infer; +export type LoRAModelFieldInputTemplate = z.infer; +export type LoRAModelFieldOutputTemplate = z.infer; +export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, +}); +export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, +}); +export type ControlNetModelFieldType = z.infer; +export type ControlNetModelFieldValue = z.infer; +export type ControlNetModelFieldInputInstance = z.infer; +export type ControlNetModelFieldOutputInstance = z.infer; +export type ControlNetModelFieldInputTemplate = z.infer; +export type ControlNetModelFieldOutputTemplate = z.infer; +export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate => + 
zControlNetModelFieldInputTemplate.safeParse(val).success; +export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue => + zControlNetModelFieldValue.safeParse(v).success; +// #endregion + +// #region IPAdapterModelField +export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, +}); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + default: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, +}); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer; +export type IPAdapterModelFieldInputInstance = z.infer; +export type IPAdapterModelFieldOutputInstance = z.infer; +export type IPAdapterModelFieldInputTemplate = z.infer; +export type IPAdapterModelFieldOutputTemplate = z.infer; +export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue => + zIPAdapterModelFieldValue.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export type T2IAdapterModelFieldType = z.infer; +export type T2IAdapterModelFieldValue = z.infer; +export type T2IAdapterModelFieldInputInstance = z.infer; +export type T2IAdapterModelFieldOutputInstance = z.infer; +export type T2IAdapterModelFieldInputTemplate = z.infer; +export type T2IAdapterModelFieldOutputTemplate = z.infer; +export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const 
zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
+  type: zSchedulerFieldType,
+  value: zSchedulerFieldValue,
+});
+export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({
+  type: zSchedulerFieldType,
+});
+export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zSchedulerFieldType,
+  default: zSchedulerFieldValue,
+});
+export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+  type: zSchedulerFieldType,
+});
+export type SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
+export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
+export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
+export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
+export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
+export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
+export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
+  zSchedulerFieldInputInstance.safeParse(val).success;
+export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
+  zSchedulerFieldInputTemplate.safeParse(val).success;
+// #endregion
+
+// #region StatelessField
+/**
+ * StatelessField is a catchall for stateless fields with no UI input components. They do not
+ * support "direct" input, instead only accepting connections from other fields.
+ *
+ * This field type serves as a "generic" field type.
+ *
+ * Examples include:
+ * - Fields like UNetField or LatentsField where we do not allow direct UI input
+ * - Reserved fields like IsIntermediate
+ * - Any other field we don't have full-on schemas for
+ */
+export const zStatelessFieldType = zFieldTypeBase.extend({
+  name: z.string().min(1), // stateless --> we accept the field's name as the type
+});
+export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling
+export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
+  type: zStatelessFieldType,
+  value: zStatelessFieldValue,
+});
+export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
+  type: zStatelessFieldType,
+});
+export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zStatelessFieldType,
+  default: zStatelessFieldValue,
+  input: z.literal('connection'), // stateless --> only accepts connection inputs
+});
+export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+  type: zStatelessFieldType,
+});
+
+export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
+export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
+export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
+export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
+export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
+export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
+// #endregion
+
+/**
+ * Here we define the main field unions:
+ * - FieldType
+ * - FieldValue
+ * - FieldInputInstance
+ * - FieldOutputInstance
+ * - FieldInputTemplate
+ * - FieldOutputTemplate
+ *
+ * All stateful fields are unioned together, and then that union is unioned with StatelessField.
+ *
+ * This allows us to interact with stateful fields without needing to worry about "generic" handling
+ * for all other StatelessFields.
+ */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer; +export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + 
zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer; +export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => + zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer; +export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate => + zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts new file mode 100644 index 00000000000..86ec70fd9bd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts @@ -0,0 +1,93 @@ +import type { Edge, Node } from 'reactflow'; +import { z } from 'zod'; + +import { zClassification, zProgressImage } from './common'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { 
zSemVer } from './semver'; + +// #region InvocationTemplate +export const zInvocationTemplate = z.object({ + type: z.string(), + title: z.string(), + description: z.string(), + tags: z.array(z.string().min(1)), + inputs: z.record(zFieldInputTemplate), + outputs: z.record(zFieldOutputTemplate), + outputType: z.string().min(1), + version: zSemVer, + useCache: z.boolean(), + nodePack: z.string().min(1).nullish(), + classification: zClassification, +}); +export type InvocationTemplate = z.infer; +// #endregion + +// #region NodeData +export const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + type: z.string().trim().min(1), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + isIntermediate: z.boolean(), + useCache: z.boolean(), + version: zSemVer, + nodePack: z.string().min(1).nullish(), + inputs: z.record(zFieldInputInstance), + outputs: z.record(zFieldOutputInstance), +}); + +export const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); +export const zCurrentImageNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('current_image'), + label: z.string(), + isOpen: z.boolean(), +}); +export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]); + +export type NotesNodeData = z.infer; +export type InvocationNodeData = z.infer; +export type CurrentImageNodeData = z.infer; +export type AnyNodeData = z.infer; + +export type InvocationNode = Node; +export type NotesNode = Node; +export type CurrentImageNode = Node; +export type AnyNode = Node; + +export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => + Boolean(node && node.type === 'current_image'); +export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => + Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type +// #endregion + +// #region NodeExecutionState +export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']); +export const zNodeExecutionState = z.object({ + nodeId: z.string().trim().min(1), + status: zNodeStatus, + progress: z.number().nullable(), + progressImage: zProgressImage.nullable(), + error: z.string().nullable(), + outputs: z.array(z.any()), +}); +export type NodeExecutionState = z.infer; +export type NodeStatus = z.infer; +// #endregion + +// #region Edges +export const zInvocationNodeEdgeExtra = z.object({ + type: z.union([z.literal('default'), z.literal('collapsed')]), +}); +export type InvocationNodeEdgeExtra = z.infer; +export type InvocationNodeEdge = Edge; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts new file mode 100644 index 00000000000..0cc30499e38 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts @@ -0,0 +1,77 @@ +import { z } from 'zod'; + +import { + zControlField, + zIPAdapterField, + zLoRAModelField, + zMainModelField, + zSDXLRefinerModelField, + zT2IAdapterField, + zVAEModelField, +} from './common'; + +// #region Metadata-optimized versions of schemas +// TODO: It's possible that `deepPartial` 
will be deprecated: +// - https://github.com/colinhacks/zod/issues/2106 +// - https://github.com/colinhacks/zod/issues/2854 +export const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); +const zControlNetMetadataItem = zControlField.deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); +const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); +const zModelMetadataItem = zMainModelField.deepPartial(); +const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +export type LoRAMetadataItem = z.infer; +export type ControlNetMetadataItem = z.infer; +export type IPAdapterMetadataItem = z.infer; +export type T2IAdapterMetadataItem = z.infer; +export type SDXLRefinerModelMetadataItem = z.infer; +export type ModelMetadataItem = z.infer; +export type VAEModelMetadataItem = z.infer; +// #endregion + +// #region CoreMetadata +export const zCoreMetadata = z + .object({ + app_version: z.string().nullish().catch(null), + generation_mode: z.string().nullish().catch(null), + created_by: z.string().nullish().catch(null), + positive_prompt: z.string().nullish().catch(null), + negative_prompt: z.string().nullish().catch(null), + width: z.number().int().nullish().catch(null), + height: z.number().int().nullish().catch(null), + seed: z.number().int().nullish().catch(null), + rand_device: z.string().nullish().catch(null), + cfg_scale: z.number().nullish().catch(null), + cfg_rescale_multiplier: z.number().nullish().catch(null), + steps: z.number().int().nullish().catch(null), + scheduler: z.string().nullish().catch(null), + clip_skip: z.number().int().nullish().catch(null), + model: zModelMetadataItem.nullish().catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), + vae: zVAEModelMetadataItem.nullish().catch(null), + strength: z.number().nullish().catch(null), + hrf_enabled: z.boolean().nullish().catch(null), + hrf_strength: z.number().nullish().catch(null), + hrf_method: z.string().nullish().catch(null), + init_image: z.string().nullish().catch(null), + positive_style_prompt: z.string().nullish().catch(null), + negative_style_prompt: z.string().nullish().catch(null), + refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null), + refiner_cfg_scale: z.number().nullish().catch(null), + refiner_steps: z.number().int().nullish().catch(null), + refiner_scheduler: z.string().nullish().catch(null), + refiner_positive_aesthetic_score: z.number().nullish().catch(null), + refiner_negative_aesthetic_score: z.number().nullish().catch(null), + refiner_start: z.number().nullish().catch(null), + }) + .passthrough(); +export type CoreMetadata = z.infer; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts new file mode 100644 index 00000000000..83d774439a3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts @@ -0,0 +1,86 @@ +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { + InputFieldJSONSchemaExtra, + InvocationJSONSchemaExtra, + OutputFieldJSONSchemaExtra, +} from 'services/api/types'; + +// Janky customization of OpenAPI Schema :/ + +export type InvocationSchemaExtra = 
InvocationJSONSchemaExtra & { + output: OpenAPIV3_1.ReferenceObject; // the output of the invocation + title: string; + category?: string; + tags?: string[]; + version: string; + properties: Omit< + NonNullable & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra), + 'type' + > & { + type: Omit & { + default: string; + }; + use_cache: Omit & { + default: boolean; + }; + }; +}; + +export type InvocationSchemaType = { + default: string; // the type of the invocation +}; + +export type InvocationBaseSchemaObject = Omit & + InvocationSchemaExtra; + +export type InvocationOutputSchemaObject = Omit & { + properties: OpenAPIV3_1.SchemaObject['properties'] & { + type: Omit & { + default: string; + }; + } & { + class: 'output'; + }; +}; + +export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra; + +export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; + +export interface ArraySchemaObject extends InvocationBaseSchemaObject { + type: OpenAPIV3_1.ArraySchemaObjectType; + items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; +} +export interface NonArraySchemaObject extends InvocationBaseSchemaObject { + type?: OpenAPIV3_1.NonArraySchemaObjectType; +} + +export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' }; + +export const isSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); + +export const isInvocationSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject +): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation'; + +export const isInvocationOutputSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject +): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output'; + +export const isInvocationFieldSchema = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject +): obj is InvocationFieldSchema => !('$ref' in obj); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts new file mode 100644 index 00000000000..3ba330eac47 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts @@ -0,0 +1,21 @@ +import { z } from 'zod'; + +// Schemas and types for working with semver + +const zVersionInt = z.coerce.number().int().min(0); + +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: 
Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts new file mode 100644 index 00000000000..723a354013b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts @@ -0,0 +1,89 @@ +import { z } from 'zod'; + +import { zFieldIdentifier } from './field'; +import { zInvocationNodeData, zNotesNodeData } from './invocation'; + +// #region Workflow misc +export const zXYPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); +export type XYPosition = z.infer; + +export const zDimension = z.number().gt(0).nullish(); +export type Dimension = z.infer; + +export const zWorkflowCategory = z.enum(['user', 'default', 'project']); +export type WorkflowCategory = z.infer; +// #endregion + +// #region Workflow Nodes +export const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); + +export type WorkflowInvocationNode = z.infer; +export type WorkflowNotesNode = z.infer; +export type WorkflowNode = z.infer; + +export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode => + zWorkflowInvocationNode.safeParse(val).success; +// #endregion + +// #region Workflow Edges +export const zWorkflowEdgeBase = z.object({ + id: z.string().trim().min(1), + source: z.string().trim().min(1), + target: z.string().trim().min(1), +}); +export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ + type: z.literal('default'), + sourceHandle: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), +}); +export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ + type: z.literal('collapsed'), +}); +export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]); + +export type WorkflowEdgeDefault = z.infer; +export type WorkflowEdgeCollapsed = z.infer; +export type WorkflowEdge = z.infer; +// #endregion + +// #region Workflow +export const zWorkflowV2 = z.object({ + id: z.string().min(1).optional(), + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + category: zWorkflowCategory.default('user'), + version: z.literal('2.0.0'), + }), +}); +export type WorkflowV2 = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts new file mode 100644 index 00000000000..7cb1ea230ce --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts @@ -0,0 +1,18 @@ +import type { WorkflowCategory, WorkflowV3, XYPosition } from 'features/nodes/types/workflow'; +import type * as ReactFlow from 'reactflow'; +import type { S } from 'services/api/types'; +import type { Equals, Extends } from 'tsafe'; +import { assert } from 'tsafe'; +import { describe, test } from 'vitest'; 
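// A minimal usage sketch for the zSemVer / zParsedSemver schemas defined above; illustrative
// only, and the absolute import path below mirrors the new v2 types directory introduced here,
// so treat it as an assumption about how the module would be consumed rather than a given.
import { zParsedSemver } from 'features/nodes/types/v2/semver';

// zSemVer only verifies that each dot-separated part coerces to a non-negative integer;
// zParsedSemver additionally transforms the validated string into numbers.
const { major, minor, patch } = zParsedSemver.parse('3.0.0'); // { major: 3, minor: 0, patch: 0 }
// A malformed version such as '3.0' or 'v3.0.0' fails the refine step and throws a ZodError.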
+ +/** + * These types originate from the server and are recreated as zod schemas manually, for use at runtime. + * The tests ensure that the types are correctly recreated. + */ + +describe('Workflow types', () => { + test('XYPosition', () => assert>()); + test('WorkflowCategory', () => assert>()); + // @ts-expect-error TODO(psyche): Need to revise server types! + test('WorkflowV3', () => assert>()); +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 723a354013b..adad7c0f219 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({ id: z.string().trim().min(1), type: z.literal('invocation'), data: zInvocationNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNotesNode = z.object({ id: z.string().trim().min(1), type: z.literal('notes'), data: zNotesNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); @@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer; // #endregion // #region Workflow -export const zWorkflowV2 = z.object({ +export const zWorkflowV3 = z.object({ id: z.string().min(1).optional(), name: z.string(), author: z.string(), @@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('2.0.0'), + version: z.literal('3.0.0'), }), }); -export type WorkflowV2 = z.infer; +export type WorkflowV3 = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 7413302fa57..5632cfd1122 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import type { DenoiseLatentsInvocation, @@ -314,11 +315,16 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = ); copyConnectionsToDenoiseLatentsHrf(graph); + // The original l2i node is unnecessary now, remove it + graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE); + delete graph.nodes[LATENTS_TO_IMAGE]; + graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = { type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, fp32: originalLatentsToImageNode?.fp32, - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.edges.push( { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts deleted file mode 100644 index 5c78ad804ed..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts +++ /dev/null @@ -1,78 +0,0 @@ -import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 
'features/ui/store/uiSelectors'; -import type { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; - -import { - CANVAS_OUTPUT, - LATENTS_TO_IMAGE, - LATENTS_TO_IMAGE_HRF_HR, - LINEAR_UI_OUTPUT, - NSFW_CHECKER, - WATERMARKER, -} from './constants'; - -/** - * Set the `use_cache` field on the linear/canvas graph's final image output node to False. - */ -export const addLinearUIOutputNode = (state: RootState, graph: NonNullableGraph): void => { - const activeTabName = activeTabNameSelector(state); - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const { autoAddBoardId } = state.gallery; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - is_intermediate, - use_cache: false, - board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId }, - }; - - graph.nodes[LINEAR_UI_OUTPUT] = linearUIOutputNode; - - const destination = { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }; - - if (WATERMARKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: WATERMARKER, - field: 'image', - }, - destination, - }); - } else if (NSFW_CHECKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination, - }); - } else if (CANVAS_OUTPUT in graph.nodes) { - graph.edges.push({ - source: { - node_id: CANVAS_OUTPUT, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE_HRF_HR, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination, - }); - } -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts index 4a8e77abfa0..35fc3246890 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts @@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store'; import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addNSFWCheckerToGraph = ( state: RootState, @@ -21,7 +22,8 @@ export const addNSFWCheckerToGraph = ( const nsfwCheckerNode: ImageNSFWBlurInvocation = { id: NSFW_CHECKER, type: 'img_nsfw', - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts index 708353e4d6a..fc4d998969d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts @@ -24,7 +24,7 @@ import { SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getSDXLStylePrompts } from './graphBuilderUtils'; import { upsertMetadata } from './metadata'; export const addSDXLRefinerToGraph = ( diff 
--git a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts index 99c5c07be47..61beb11df49 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts @@ -1,5 +1,4 @@ import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import type { ImageNSFWBlurInvocation, ImageWatermarkInvocation, @@ -8,16 +7,13 @@ import type { } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addWatermarkerToGraph = ( state: RootState, graph: NonNullableGraph, nodeIdToAddTo = LATENTS_TO_IMAGE ): void => { - const activeTabName = activeTabNameSelector(state); - - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined; const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined; @@ -30,7 +26,8 @@ export const addWatermarkerToGraph = ( const watermarkerNode: ImageWatermarkInvocation = { id: WATERMARKER, type: 'img_watermark', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[WATERMARKER] = watermarkerNode; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts index fa20206d91f..52c09b1db06 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts @@ -1,51 +1,33 @@ -import type { BoardId } from 'features/gallery/store/types'; -import type { ParamESRGANModelName } from 'features/parameters/store/postprocessingSlice'; -import type { ESRGANInvocation, Graph, LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; +import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; +import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types'; -import { ESRGAN, LINEAR_UI_OUTPUT } from './constants'; +import { ESRGAN } from './constants'; import { addCoreMetadataNode, upsertMetadata } from './metadata'; type Arg = { image_name: string; - esrganModelName: ParamESRGANModelName; - autoAddBoardId: BoardId; + state: RootState; }; -export const buildAdHocUpscaleGraph = ({ image_name, esrganModelName, autoAddBoardId }: Arg): Graph => { +export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => { + const { esrganModelName } = state.postprocessing; + const realesrganNode: ESRGANInvocation = { id: ESRGAN, type: 'esrgan', image: { image_name }, model_name: esrganModelName, - is_intermediate: true, - }; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - use_cache: false, - is_intermediate: false, - board: autoAddBoardId === 'none' ? 
undefined : { board_id: autoAddBoardId }, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; const graph: NonNullableGraph = { id: `adhoc-esrgan-graph`, nodes: { [ESRGAN]: realesrganNode, - [LINEAR_UI_OUTPUT]: linearUIOutputNode, }, - edges: [ - { - source: { - node_id: ESRGAN, - field: 'image', - }, - destination: { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }, - }, - ], + edges: [], }; addCoreMetadataNode(graph, {}, ESRGAN); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts index 3002e05441b..bc6a83f4fa6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -132,7 +132,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima [CANVAS_OUTPUT]: { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -242,7 +243,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -284,7 +286,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -355,7 +358,5 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index bb52a44a8e4..d983b9cf4f5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -12,7 +13,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from 
'./addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -191,7 +191,8 @@ export const buildCanvasInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -663,7 +664,5 @@ export const buildCanvasInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index b82b55cfee7..1d028943818 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, @@ -11,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -200,7 +200,8 @@ export const buildCanvasOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -769,7 +770,5 @@ export const buildCanvasOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts index 1b586371a02..58269afce3f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'servi import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -26,7 +25,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -246,7 +245,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + 
is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -368,7 +368,5 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index 00fea9a37e6..5902dee2fc4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -12,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -44,7 +43,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Inpaint graph. @@ -190,7 +189,8 @@ export const buildCanvasSDXLInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -687,7 +687,5 @@ export const buildCanvasSDXLInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index f85760d8f2e..7a78750e8d2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -11,7 +11,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -46,7 +45,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Outpaint graph. 
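Every canvas and linear graph builder in this diff repeats the same substitution on its final image-producing node: the locally computed `is_intermediate` flag is replaced by `getIsIntermediate(state)` and a `board` field is added via `getBoardField(state)` (both helpers are defined in `graphBuilderUtils.ts` further down in this diff), which is what allows the dedicated `linear_ui_output` node to be deleted. A minimal sketch of the pattern, using a hypothetical `buildOutputNode` helper that is not part of the diff:

```ts
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { LatentsToImageInvocation } from 'services/api/types';

import { CANVAS_OUTPUT } from './constants';

// Hypothetical builder fragment (not part of the diff) showing the fields that
// each hunk now sets directly on the output node instead of deferring them to
// a trailing linear_ui_output node.
const buildOutputNode = (state: RootState): LatentsToImageInvocation => ({
  type: 'l2i',
  id: CANVAS_OUTPUT,
  // On the canvas tab this follows the shouldAutoSave setting; on other tabs
  // the output is always a final (gallery) image.
  is_intermediate: getIsIntermediate(state),
  // Route the final image to the auto-add board, if one is selected.
  board: getBoardField(state),
  use_cache: false,
});
```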
@@ -199,7 +198,8 @@ export const buildCanvasSDXLOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -786,7 +786,5 @@ export const buildCanvasSDXLOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts index 91d9da4cb5d..22da39c67da 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -24,7 +23,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -222,7 +221,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -254,7 +254,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -330,7 +331,5 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts index 967dd3ff4a5..93f0470c7ad 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -211,7 +211,8 @@ export const buildCanvasTextToImageGraph = 
(state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -243,7 +244,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -310,7 +312,5 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts index c76776d94d3..d1f1546b23b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -117,7 +117,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [DENOISE_LATENTS]: { type: 'denoise_latents', @@ -358,7 +359,5 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts index 9ae602bcacb..de4ad7ceceb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -25,7 +24,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from 
'./metadata'; /** @@ -120,7 +119,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [SDXL_DENOISE_LATENTS]: { type: 'denoise_latents', @@ -380,7 +380,5 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts index 222dc1a3595..58b97b07c75 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -23,7 +22,7 @@ import { SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => { @@ -120,7 +119,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -281,7 +281,5 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index 0a45d91debc..b2b84cfdad7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -1,11 +1,11 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -119,7 +119,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: 
getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -267,7 +268,5 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts index 363d3191210..767bf25df0a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts @@ -9,7 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr'; export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf'; export const RESIZE_HRF = 'resize_hrf'; export const ESRGAN_HRF = 'esrgan_hrf'; -export const LINEAR_UI_OUTPUT = 'linear_ui_output'; export const NSFW_CHECKER = 'nsfw_checker'; export const WATERMARKER = 'invisible_watermark'; export const NOISE = 'noise'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts b/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts deleted file mode 100644 index e1cd8518fdd..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type { RootState } from 'app/store/store'; - -export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { - const { positivePrompt, negativePrompt } = state.generation; - const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; - - return { - positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, - negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, - }; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts new file mode 100644 index 00000000000..cb6fc9acf1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -0,0 +1,38 @@ +import type { RootState } from 'app/store/store'; +import type { BoardField } from 'features/nodes/types/common'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; + +/** + * Gets the board field, based on the autoAddBoardId setting. + */ +export const getBoardField = (state: RootState): BoardField | undefined => { + const { autoAddBoardId } = state.gallery; + if (autoAddBoardId === 'none') { + return undefined; + } + return { board_id: autoAddBoardId }; +}; + +/** + * Gets the SDXL style prompts, based on the concat setting. + */ +export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { + const { positivePrompt, negativePrompt } = state.generation; + const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; + + return { + positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, + negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, + }; +}; + +/** + * Gets the is_intermediate field, based on the active tab and shouldAutoSave setting. 
+ */ +export const getIsIntermediate = (state: RootState) => { + const activeTabName = activeTabNameSelector(state); + if (activeTabName === 'unifiedCanvas') { + return !state.canvas.shouldAutoSave; + } + return false; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts index ea40bd4660f..af19aa86eaf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts @@ -1,5 +1,5 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; -import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation'; import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; import { reduce } from 'lodash-es'; @@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe {} as Record ); - const outputs = reduce( - template.outputs, - (outputsAccumulator, outputTemplate, outputName) => { - const fieldId = uuidv4(); - - const outputFieldValue: FieldOutputInstance = { - id: fieldId, - name: outputName, - type: outputTemplate.type, - fieldKind: 'output', - }; - - outputsAccumulator[outputName] = outputFieldValue; - - return outputsAccumulator; - }, - {} as Record - ); - const node: InvocationNode = { ...SHARED_NODE_PROPERTIES, id: nodeId, @@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe isIntermediate: type === 'save_image' ? false : true, useCache: template.useCache, inputs, - outputs, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts index f195c49d30a..5ece51d0f30 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts @@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate): // Remove any fields that are not in the template clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs)); - clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs)); return clone; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index dd3cf0ad7b4..f8097566c95 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record = export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => { const fieldInstance: FieldInputInstance = { - id, name: template.name, - type: template.type, label: '', - fieldKind: 'input' as const, value: template.default ?? 
get(FIELD_VALUE_FALLBACK_MAP, template.type.name), }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts new file mode 100644 index 00000000000..d7011ad6f84 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -0,0 +1,379 @@ +import { + UnableToExtractSchemaNameFromRefError, + UnsupportedArrayItemType, + UnsupportedPrimitiveTypeError, + UnsupportedUnionError, +} from 'features/nodes/types/error'; +import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType'; +import { describe, expect, it } from 'vitest'; + +type ParseFieldTypeTestCase = { + name: string; + schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema; + expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean }; +}; + +const primitiveTypes: ParseFieldTypeTestCase[] = [ + { + name: 'Scalar IntegerField', + schema: { type: 'integer' }, + expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar FloatField', + schema: { type: 'number' }, + expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar StringField', + schema: { type: 'string' }, + expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar BooleanField', + schema: { type: 'boolean' }, + expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Collection IntegerField', + schema: { items: { type: 'integer' }, type: 'array' }, + expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection FloatField', + schema: { items: { type: 'number' }, type: 'array' }, + expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection StringField', + schema: { items: { type: 'string' }, type: 'array' }, + expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection BooleanField', + schema: { items: { type: 'boolean' }, type: 'array' }, + expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'CollectionOrScalar IntegerField', + schema: { + anyOf: [ + { + type: 'integer', + }, + { + items: { + type: 'integer', + }, + type: 'array', + }, + ], + }, + expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar FloatField', + schema: { + anyOf: [ + { + type: 'number', + }, + { + items: { + type: 'number', + }, + type: 'array', + }, + ], + }, + expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar StringField', + schema: { + anyOf: [ + { + type: 'string', + }, + { + items: { + type: 'string', + }, + type: 'array', + }, + ], + }, + expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar BooleanField', + schema: { + anyOf: [ + { + type: 'boolean', + }, + { + items: { + type: 'boolean', + }, + type: 'array', + }, + ], + }, + expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true }, + }, +]; + +const complexTypes: ParseFieldTypeTestCase[] = [ + { + name: 'Scalar 
ConditioningField', + schema: { + allOf: [ + { + $ref: '#/components/schemas/ConditioningField', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Nullable Scalar ConditioningField', + schema: { + anyOf: [ + { + $ref: '#/components/schemas/ConditioningField', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Collection ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Nullable Collection ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'CollectionOrScalar ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + $ref: '#/components/schemas/ConditioningField', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'Nullable CollectionOrScalar ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + $ref: '#/components/schemas/ConditioningField', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + }, +]; + +const specialCases: ParseFieldTypeTestCase[] = [ + { + name: 'String EnumField', + schema: { + type: 'string', + enum: ['large', 'base', 'small'], + }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'String EnumField with one value', + schema: { + const: 'Some Value', + }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (SchedulerField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'SchedulerField', + }, + expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (AnyField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'AnyField', + }, + expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (CollectionField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'CollectionField', + }, + expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + }, +]; + +describe('refObjectToSchemaName', async () => { + it('parses ref object 1', () => { + expect( + refObjectToSchemaName({ + $ref: '#/components/schemas/ImageField', + }) + ).toEqual('ImageField'); + }); + it('parses ref object 2', () => { + expect( + refObjectToSchemaName({ + $ref: '#/components/schemas/T2IAdapterModelField', + }) + ).toEqual('T2IAdapterModelField'); + }); +}); + +describe.concurrent('parseFieldType', async () => { + it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + it.each(complexTypes)('parses complex types ($name)', ({ schema, 
expected }: ParseFieldTypeTestCase) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + it.each(specialCases)('parses special case types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + + it('raises if it cannot extract a schema name from a ref', () => { + expect(() => + parseFieldType({ + allOf: [ + { + $ref: '#/components/schemas/', + }, + ], + }) + ).toThrowError(UnableToExtractSchemaNameFromRefError); + }); + + it('raises if it receives a union of mismatched types', () => { + expect(() => + parseFieldType({ + anyOf: [ + { + type: 'string', + }, + { + type: 'integer', + }, + ], + }) + ).toThrowError(UnsupportedUnionError); + }); + + it('raises if it receives a union of mismatched types (excluding null)', () => { + expect(() => + parseFieldType({ + anyOf: [ + { + type: 'string', + }, + { + type: 'integer', + }, + { + type: 'null', + }, + ], + }) + ).toThrowError(UnsupportedUnionError); + }); + + it('raises if it received an unsupported primitive type (object)', () => { + expect(() => + parseFieldType({ + type: 'object', + }) + ).toThrowError(UnsupportedPrimitiveTypeError); + }); + + it('raises if it received an unsupported primitive type (null)', () => { + expect(() => + parseFieldType({ + type: 'null', + }) + ).toThrowError(UnsupportedPrimitiveTypeError); + }); + + it('raises if it received an unsupported array item type (object)', () => { + expect(() => + parseFieldType({ + items: { + type: 'object', + }, + type: 'array', + }) + ).toThrowError(UnsupportedArrayItemType); + }); + + it('raises if it received an unsupported array item type (null)', () => { + expect(() => + parseFieldType({ + items: { + type: 'null', + }, + type: 'array', + }) + ).toThrowError(UnsupportedArrayItemType); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 14b1aefd6d3..13da6b38312 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -1,6 +1,12 @@ -import { FieldParseError } from 'features/nodes/types/error'; +import { + FieldParseError, + UnableToExtractSchemaNameFromRefError, + UnsupportedArrayItemType, + UnsupportedPrimitiveTypeError, + UnsupportedUnionError, +} from 'features/nodes/types/error'; import type { FieldType } from 'features/nodes/types/field'; -import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; import { isArraySchemaObject, isInvocationFieldSchema, @@ -42,7 +48,7 @@ const isCollectionFieldType = (fieldType: string) => { return false; }; -export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => { +export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => { if (isInvocationFieldSchema(schemaObject)) { // Check if this field has an explicit type provided by the node schema const { ui_type } = schemaObject; @@ -72,7 +78,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // This is a single ref type const name = refObjectToSchemaName(allOf[0]); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -95,7 +101,7 
@@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (isRefObject(filteredAnyOf[0])) { const name = refObjectToSchemaName(filteredAnyOf[0]); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { @@ -118,7 +124,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (filteredAnyOf.length !== 2) { // This is a union of more than 2 types, which we don't support - throw new FieldParseError( + throw new UnsupportedUnionError( t('nodes.unsupportedAnyOfLength', { count: filteredAnyOf.length, }) @@ -159,7 +165,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType }; } - throw new FieldParseError( + throw new UnsupportedUnionError( t('nodes.unsupportedMismatchedUnion', { firstType, secondType, @@ -178,7 +184,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (isSchemaObject(schemaObject.items)) { const itemType = schemaObject.items.type; if (!itemType || isArray(itemType)) { - throw new FieldParseError( + throw new UnsupportedArrayItemType( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -188,7 +194,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new FieldParseError( + throw new UnsupportedArrayItemType( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -204,7 +210,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // This is a ref object, extract the type name const name = refObjectToSchemaName(schemaObject.items); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -216,7 +222,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new FieldParseError( + throw new UnsupportedPrimitiveTypeError( t('nodes.unsupportedArrayItemType', { type: schemaObject.type, }) @@ -232,7 +238,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } else if (isRefObject(schemaObject)) { const name = refObjectToSchemaName(schemaObject); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 70775a98823..720da164648 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import type { NodesState, WorkflowsState } from 'features/nodes/store/types'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 
'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; import { cloneDeep, pick } from 'lodash-es'; import { fromZodError } from 'zod-validation-error'; @@ -25,14 +25,14 @@ const workflowKeys = [ 'exposedFields', 'meta', 'id', -] satisfies (keyof WorkflowV2)[]; +] satisfies (keyof WorkflowV3)[]; -export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2; +export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3; -export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => { +export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => { const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys); - const newWorkflow: WorkflowV2 = { + const newWorkflow: WorkflowV3 = { ...clonedWorkflow, nodes: [], edges: [], @@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } else if (isNotesNode(node) && node.type) { newWorkflow.nodes.push({ @@ -54,8 +52,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } }); @@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo return newWorkflow; }; -export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => { +export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => { // builds what really, really should be a valid workflow const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow }); // but bc we are storing this in the DB, let's be extra sure - const result = zWorkflowV2.safeParse(workflowToValidate); + const result = zWorkflowV3.safeParse(workflowToValidate); if (!result.success) { const { message } = fromZodError(result.error, { diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index a2677f3d174..a023c96ba92 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; +import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { z } from 'zod'; @@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = 
$store.get()?.getState().nodeTemplates.templates; + const invocationTemplates = $store.get()?.getState().nodes.templates; if (!invocationTemplates) { throw new Error(t('app.storeNotInitialized')); @@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { return zWorkflowV2.parse(workflowToMigrate); }; +const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { + // Bump version + (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; + // Parsing strips out any extra properties not in the latest version + return zWorkflowV3.parse(workflowToMigrate); +}; + /** * Parses a workflow and migrates it to the latest version if necessary. */ -export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => { +export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { const workflowVersionResult = zWorkflowMetaVersion.safeParse(data); if (!workflowVersionResult.success) { throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion')); } - const { version } = workflowVersionResult.data.meta; + let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3; - if (version === '1.0.0') { - const v1 = zWorkflowV1.parse(data); - return migrateV1toV2(v1); + if (workflow.meta.version === '1.0.0') { + const v1 = zWorkflowV1.parse(workflow); + workflow = migrateV1toV2(v1); } - if (version === '2.0.0') { - return zWorkflowV2.parse(data); + if (workflow.meta.version === '2.0.0') { + const v2 = zWorkflowV2.parse(workflow); + workflow = migrateV2toV3(v2); } - throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version })); + return workflow as WorkflowV3; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 848d2aee77a..5096e588b06 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,6 @@ import { parseify } from 'common/util/serialize'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { t } from 'i18next'; @@ -16,7 +16,7 @@ type WorkflowWarning = { }; type ValidateWorkflowResult = { - workflow: WorkflowV2; + workflow: WorkflowV3; warnings: WorkflowWarning[]; }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts index 5d484b68971..7b49d70213f 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts @@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher'; import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { workflowUpdated } from 'features/workflowLibrary/store/actions'; import { useCallback, useRef } from 'react'; import { 
useTranslation } from 'react-i18next'; @@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = { type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn; -export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required => +export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required => Boolean(workflow.id); export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => { diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index da036b6d40a..3393e74d486 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -19,80 +19,14 @@ export type paths = { */ post: operations["parse_dynamicprompts"]; }; - "/api/v1/models/": { - /** - * List Models - * @description Gets a list of models - */ - get: operations["list_models"]; - }; - "/api/v1/models/{base_model}/{model_type}/{model_name}": { - /** - * Delete Model - * @description Delete Model - */ - delete: operations["del_model"]; - /** - * Update Model - * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. - */ - patch: operations["update_model"]; - }; - "/api/v1/models/import": { - /** - * Import Model - * @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically - */ - post: operations["import_model"]; - }; - "/api/v1/models/add": { - /** - * Add Model - * @description Add a model using the configuration information appropriate for its type. Only local models can be added by path - */ - post: operations["add_model"]; - }; - "/api/v1/models/convert/{base_model}/{model_type}/{model_name}": { - /** - * Convert Model - * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. - */ - put: operations["convert_model"]; - }; - "/api/v1/models/search": { - /** Search For Models */ - get: operations["search_for_models"]; - }; - "/api/v1/models/ckpt_confs": { - /** - * List Ckpt Configs - * @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT. - */ - get: operations["list_ckpt_configs"]; - }; - "/api/v1/models/sync": { - /** - * Sync To Config - * @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize - * in-memory data structures with disk data structures. - */ - post: operations["sync_to_config"]; - }; - "/api/v1/models/merge/{base_model}": { - /** - * Merge Models - * @description Convert a checkpoint model into a diffusers model - */ - put: operations["merge_models"]; - }; - "/api/v1/model/record/": { + "/api/v2/models/": { /** * List Model Records * @description Get a list of models. */ get: operations["list_model_records"]; }; - "/api/v1/model/record/i/{key}": { + "/api/v2/models/i/{key}": { /** * Get Model Record * @description Get a model record @@ -112,50 +46,76 @@ export type paths = { */ patch: operations["update_model_record"]; }; - "/api/v1/model/record/meta": { + "/api/v2/models/summary": { /** * List Model Summary * @description Gets a page of model summary data. */ get: operations["list_model_summary"]; }; - "/api/v1/model/record/meta/i/{key}": { + "/api/v2/models/meta/i/{key}": { /** * Get Model Metadata * @description Get a model metadata object. 
*/ get: operations["get_model_metadata"]; }; - "/api/v1/model/record/tags": { + "/api/v2/models/tags": { /** * List Tags * @description Get a unique set of all the model tags. */ get: operations["list_tags"]; }; - "/api/v1/model/record/tags/search": { + "/api/v2/models/tags/search": { /** * Search By Metadata Tags * @description Get a list of models. */ get: operations["search_by_metadata_tags"]; }; - "/api/v1/model/record/i/": { + "/api/v2/models/i/": { /** * Add Model Record * @description Add a model using the configuration information appropriate for its type. */ post: operations["add_model_record"]; }; - "/api/v1/model/record/import": { + "/api/v2/models/heuristic_import": { /** - * List Model Install Jobs - * @description Return list of model install jobs. + * Heuristic Import + * @description Install a model using a string identifier. + * + * `source` can be any of the following. + * + * 1. A path on the local filesystem ('C:\users\fred\model.safetensors') + * 2. A Url pointing to a single downloadable model file + * 3. A HuggingFace repo_id with any of the following formats: + * - model/name + * - model/name:fp16:vae + * - model/name::vae -- use default precision + * - model/name:fp16:path/to/model.safetensors + * - model/name::path/to/model.safetensors + * + * `config` is an optional dict containing model configuration values that will override + * the ones that are probed automatically. + * + * `access_token` is an optional access token for use with Urls that require + * authentication. + * + * Models will be downloaded, probed, configured and installed in a + * series of background threads. The return object has `status` attribute + * that can be used to monitor progress. + * + * See the documentation for `import_model_record` for more information on + * interpreting the job information returned by this route. */ - get: operations["list_model_install_jobs"]; + post: operations["heuristic_import_model"]; + }; + "/api/v2/models/install": { /** * Import Model - * @description Add a model using its local path, repo_id, or remote URL. + * @description Install a model using its local path, repo_id, or remote URL. * * Models will be downloaded, probed, configured and installed in a * series of background threads. The return object has `status` attribute @@ -166,32 +126,38 @@ export type paths = { * appropriate value: * * * To install a local path using LocalModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "local", * "path": "/path/to/model", * "inplace": false - * }` - * The "inplace" flag, if true, will register the model in place in its - * current filesystem location. Otherwise, the model will be copied - * into the InvokeAI models directory. + * } + * ``` + * The "inplace" flag, if true, will register the model in place in its + * current filesystem location. Otherwise, the model will be copied + * into the InvokeAI models directory. * * * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "hf", * "repo_id": "stabilityai/stable-diffusion-2.0", * "variant": "fp16", * "subfolder": "vae", * "access_token": "f5820a918aaf01" - * }` - * The `variant`, `subfolder` and `access_token` fields are optional. + * } + * ``` + * The `variant`, `subfolder` and `access_token` fields are optional. 
     * * To install a remote model using an arbitrary URL, pass:
-     * `{
+     * ```
+     * {
     * "type": "url",
     * "url": "http://www.civitai.com/models/123456",
     * "access_token": "f5820a918aaf01"
-     * }`
-     * The `access_token` field is optonal
+     * }
+     * ```
+     * The `access_token` field is optional
     *
     * The model's configuration record will be probed and filled in
     * automatically. To override the default guesses, pass "metadata"
@@ -200,26 +166,51 @@
     * Installation occurs in the background. Either use list_model_install_jobs()
     * to poll for completion, or listen on the event bus for the following events:
     *
-     * "model_install_running"
-     * "model_install_completed"
-     * "model_install_error"
+     * * "model_install_running"
+     * * "model_install_completed"
+     * * "model_install_error"
     *
     * On successful completion, the event's payload will contain the field "key"
     * containing the installed ID of the model. On an error, the event's payload
     * will contain the fields "error_type" and "error" describing the nature of the
     * error and its traceback, respectively.
     */
-    post: operations["import_model_record"];
+    post: operations["import_model"];
+  };
+  "/api/v2/models/import": {
+    /**
+     * List Model Install Jobs
+     * @description Return the list of model install jobs.
+     *
+     * Install jobs have a numeric `id`, a `status`, and other fields that provide information on
+     * the nature of the job and its progress. The `status` is one of:
+     *
+     * * "waiting" -- Job is waiting in the queue to run
+     * * "downloading" -- Model file(s) are downloading
+     * * "running" -- Model has downloaded and the model probing and registration process is running
+     * * "completed" -- Installation completed successfully
+     * * "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
+     * * "cancelled" -- Job was cancelled before completion.
+     *
+     * Once completed, information about the model such as its size, base
+     * model, type, and metadata can be retrieved from the `config_out`
+     * field. For multi-file models such as diffusers, information on individual files
+     * can be retrieved from `download_parts`.
+     *
+     * See the example and schema below for more information.
+     */
+    get: operations["list_model_install_jobs"];
    /**
     * Prune Model Install Jobs
     * @description Prune all completed and errored jobs from the install job list.
     */
    patch: operations["prune_model_install_jobs"];
  };
-  "/api/v1/model/record/import/{id}": {
+  "/api/v2/models/import/{id}": {
    /**
     * Get Model Install Job
-     * @description Return model install job corresponding to the given source.
+     * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
+     * for information on the format of the return value.
     */
    get: operations["get_model_install_job"];
    /**
@@ -228,7 +219,7 @@
     */
    delete: operations["cancel_model_install_job"];
  };
-  "/api/v1/model/record/sync": {
+  "/api/v2/models/sync": {
    /**
     * Sync Models To Config
     * @description Traverse the models and autoimport directories.
@@ -238,17 +229,29 @@
     */
    patch: operations["sync_models_to_config"];
  };
-  "/api/v1/model/record/merge": {
+  "/api/v2/models/convert/{key}": {
+    /**
+     * Convert Model
+     * @description Permanently convert a model into diffusers format, replacing the safetensors version.
+     * Note that during the conversion process the key and model hash will change.
+     * The return value is the model configuration for the converted model.
+     */
+    put: operations["convert_model"];
+  };
+  "/api/v2/models/merge": {
    /**
     * Merge
-     * @description Merge diffusers models.
-     *
-     * keys: List of 2-3 model keys to merge together. All models must use the same base type.
-     * merged_model_name: Name for the merged model [Concat model names]
-     * alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
-     * force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
-     * interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
-     * merge_dest_directory: Specify a directory to store the merged model in [models directory]
+     * @description Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
+     * ```
+     * Argument                Description [default]
+     * --------                ----------------------
+     * keys                    List of 2-3 model keys to merge together. All models must use the same base type.
+     * merged_model_name       Name for the merged model [Concat model names]
+     * alpha                   Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
+     * force                   If true, force the merge even if the models were generated by different versions of the diffusers library [False]
+     * interp                  Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
+     * merge_dest_directory    Specify a directory to store the merged model in [models directory]
+     * ```
     */
    put: operations["merge"];
  };
@@ -680,70 +683,6 @@ export type components = {
       */
      type: "add";
    };
-    /**
-     * Adjust Image Hue Plus
-     * @description Adjusts the Hue of an image by rotating it in the selected color space
-     */
-    AdjustImageHuePlusInvocation: {
-      /** @description Optional metadata to be saved with the image */
-      metadata?: components["schemas"]["MetadataField"] | null;
-      /**
-       * Id
-       * @description The id of this instance of an invocation. Must be unique among all instances of invocations.
-       */
-      id: string;
-      /**
-       * Is Intermediate
-       * @description Whether or not this is an intermediate invocation.
-       * @default false
-       */
-      is_intermediate?: boolean;
-      /**
-       * Use Cache
-       * @description Whether or not to use the cache
-       * @default true
-       */
-      use_cache?: boolean;
-      /** @description The image to adjust */
-      image?: components["schemas"]["ImageField"];
-      /**
-       * Space
-       * @description Color space in which to rotate hue by polar coords (*: non-invertible)
-       * @default Okhsl
-       * @enum {string}
-       */
-      space?: "HSV / HSL / RGB" | "Okhsl" | "Okhsv" | "*Oklch / Oklab" | "*LCh / CIELab" | "*UPLab (w/CIELab_to_UPLab.icc)";
-      /**
-       * Degrees
-       * @description Degrees by which to rotate image hue
-       * @default 0
-       */
-      degrees?: number;
-      /**
-       * Preserve Lightness
-       * @description Whether to preserve CIELAB lightness values
-       * @default false
-       */
-      preserve_lightness?: boolean;
-      /**
-       * Ok Adaptive Gamut
-       * @description Higher preserves chroma at the expense of lightness (Oklab)
-       * @default 0.05
-       */
-      ok_adaptive_gamut?: number;
-      /**
-       * Ok High Precision
-       * @description Use more steps in computing gamut (Oklab/Okhsv/Okhsl)
-       * @default true
-       */
-      ok_high_precision?: boolean;
-      /**
-       * type
-       * @default img_hue_adjust_plus
-       * @constant
-       */
-      type: "img_hue_adjust_plus";
-    };
    /**
     * AppConfig
     * @description App Config Response
@@ -879,6 +818,12 @@
       */
      type?: "basemetadata";
    };
+    /**
+     * BaseModelType
+     * @description Base model type.
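For orientation, here is a minimal TypeScript sketch of how a client might drive the v2 install routes described in the docstrings above: POST a source to `/api/v2/models/install`, then poll `/api/v2/models/import` for the job's `status`. The base URL, the loose response typings, and the assumption that the job-list route returns a plain array are illustrative guesses, not something defined by this diff; the generated `paths`/`operations`/`components` types in this file remain the source of truth.

```ts
// Illustrative sketch only. Field names follow the route docstrings above
// (HFModelSource, Body_import_model's `source` field, job `status` values);
// the base URL and exact response shapes are assumptions.
const baseUrl = 'http://127.0.0.1:9090';

type InstallJob = { id: number; status: string; config_out?: unknown };

async function installFromHuggingFace(repoId: string): Promise<InstallJob> {
  const res = await fetch(`${baseUrl}/api/v2/models/install`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    // `source` follows the HFModelSource form shown in the docstring;
    // `variant`, `subfolder` and `access_token` are optional.
    body: JSON.stringify({ source: { type: 'hf', repo_id: repoId, variant: 'fp16' } }),
  });
  if (!res.ok) {
    throw new Error(`install request failed with HTTP ${res.status}`);
  }
  // The route returns an install job with a numeric `id` and a `status`.
  return (await res.json()) as InstallJob;
}

async function waitForInstall(jobId: number, pollMs = 1000): Promise<InstallJob | undefined> {
  // Polling is one option; the docstrings also describe listening for the
  // "model_install_running" / "model_install_completed" / "model_install_error"
  // events on the event bus instead.
  for (;;) {
    const res = await fetch(`${baseUrl}/api/v2/models/import`);
    const jobs = (await res.json()) as InstallJob[];
    const job = jobs.find((j) => j.id === jobId);
    if (!job || ['completed', 'error', 'cancelled'].includes(job.status)) {
      return job;
    }
    await new Promise((resolve) => setTimeout(resolve, pollMs));
  }
}
```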
+ * @enum {string} + */ + BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; /** Batch */ Batch: { /** @@ -968,6 +913,8 @@ export type components = { * @description Creates a blank image and forwards it to the pipeline */ BlankImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -1225,19 +1172,6 @@ export type components = { }; /** Body_import_model */ Body_import_model: { - /** - * Location - * @description A model path, repo_id or URL to import - */ - location: string; - /** - * Prediction Type - * @description Prediction type for SDv2 checkpoints and rare SDv1 checkpoints - */ - prediction_type?: ("v_prediction" | "epsilon" | "sample") | null; - }; - /** Body_import_model_record */ - Body_import_model_record: { /** Source */ source: components["schemas"]["LocalModelSource"] | components["schemas"]["HFModelSource"] | components["schemas"]["CivitaiModelSource"] | components["schemas"]["URLModelSource"]; /** @@ -1278,11 +1212,6 @@ export type components = { */ merge_dest_directory?: string | null; }; - /** Body_merge_models */ - Body_merge_models: { - /** @description Model configuration */ - body: components["schemas"]["MergeModelsBody"]; - }; /** Body_parse_dynamicprompts */ Body_parse_dynamicprompts: { /** @@ -1452,39 +1381,6 @@ export type components = { */ type: "boolean_output"; }; - /** - * BRIA AI Background Removal - * @description Uses the new Bria 1.4 model to remove backgrounds from images. - */ - BriaRemoveBackgroundInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to crop */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default bria_bg_remove - * @constant - */ - type: "bria_bg_remove"; - }; /** * CLIPOutput * @description Base class for invocations that output a CLIP field @@ -1507,11 +1403,18 @@ export type components = { * @description Model config for ClipVision. 
*/ CLIPVisionDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default clip_vision @@ -1539,51 +1442,37 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; - }; - /** CLIPVisionModelDiffusersConfig */ - CLIPVisionModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default clip_vision - * @constant - */ - model_type: "clip_vision"; - /** Path */ - path: string; - /** Description */ - description?: string | null; /** - * Model Format - * @constant + * Last Modified + * @description timestamp for modification time */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; + last_modified?: number | null; }; /** CLIPVisionModelField */ CLIPVisionModelField: { /** - * Model Name - * @description Name of the CLIP Vision image encoder model + * Key + * @description Key to the CLIP Vision image encoder model */ - model_name: string; - /** @description Base model (usually 'Any') */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** - * CMYK Color Separation - * @description Get color images from a base color and two others that subtractively mix to obtain it + * CV2 Infill + * @description Infills transparent areas of an image using OpenCV Inpainting */ - CMYKColorSeparationInvocation: { + CV2InfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -1603,87 +1492,79 @@ export type components = { * @default true */ use_cache?: boolean; + /** @description The image to infill */ + image?: components["schemas"]["ImageField"]; /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * C Value - * @description Desired final cyan value - * @default 0 + * type + * @default infill_cv2 + * @constant */ - c_value?: number; + type: "infill_cv2"; + }; + /** + * Calculate Image Tiles Even Split + * @description Calculate the coordinates and overlaps of tiles that cover a target image shape. + */ + CalculateImageTilesEvenSplitInvocation: { /** - * M Value - * @description Desired final magenta value - * @default 25 + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. */ - m_value?: number; + id: string; /** - * Y Value - * @description Desired final yellow value - * @default 28 + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
+ * @default false */ - y_value?: number; + is_intermediate?: boolean; /** - * K Value - * @description Desired final black value - * @default 76 + * Use Cache + * @description Whether or not to use the cache + * @default true */ - k_value?: number; + use_cache?: boolean; /** - * C Split - * @description Desired cyan split point % [0..1.0] - * @default 0.5 + * Image Width + * @description The image width, in pixels, to calculate tiles for. + * @default 1024 */ - c_split?: number; + image_width?: number; /** - * M Split - * @description Desired magenta split point % [0..1.0] - * @default 1 + * Image Height + * @description The image height, in pixels, to calculate tiles for. + * @default 1024 */ - m_split?: number; + image_height?: number; /** - * Y Split - * @description Desired yellow split point % [0..1.0] - * @default 0 + * Num Tiles X + * @description Number of tiles to divide image into on the x axis + * @default 2 */ - y_split?: number; + num_tiles_x?: number; /** - * K Split - * @description Desired black split point % [0..1.0] - * @default 0.5 + * Num Tiles Y + * @description Number of tiles to divide image into on the y axis + * @default 2 */ - k_split?: number; + num_tiles_y?: number; /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} + * Overlap + * @description The overlap, in pixels, between adjacent tiles. + * @default 128 */ - profile?: "Default" | "PIL"; + overlap?: number; /** * type - * @default cmyk_separation + * @default calculate_image_tiles_even_split * @constant */ - type: "cmyk_separation"; + type: "calculate_image_tiles_even_split"; }; /** - * CMYK Merge - * @description Merge subtractive color channels (CMYK+alpha) + * Calculate Image Tiles + * @description Calculate the coordinates and overlaps of tiles that cover a target image shape. */ - CMYKMergeInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; + CalculateImageTilesInvocation: { /** * Id * @description The id of this instance of an invocation. Must be unique among all instances of invocations. @@ -1701,284 +1582,16 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description The c channel */ - c_channel?: components["schemas"]["ImageField"] | null; - /** @description The m channel */ - m_channel?: components["schemas"]["ImageField"] | null; - /** @description The y channel */ - y_channel?: components["schemas"]["ImageField"] | null; - /** @description The k channel */ - k_channel?: components["schemas"]["ImageField"] | null; - /** @description The alpha channel */ - alpha_channel?: components["schemas"]["ImageField"] | null; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; /** - * type - * @default cmyk_merge - * @constant + * Image Width + * @description The image width, in pixels, to calculate tiles for. 
+ * @default 1024 */ - type: "cmyk_merge"; - }; - /** - * CMYKSeparationOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSeparationOutput: { - /** @description Blank image of the specified color */ - color_image: components["schemas"]["ImageField"]; + image_width?: number; /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** @description Blank image of the first separated color */ - part_a: components["schemas"]["ImageField"]; - /** - * Rgb Red A - * @description R value of color part A - */ - rgb_red_a: number; - /** - * Rgb Green A - * @description G value of color part A - */ - rgb_green_a: number; - /** - * Rgb Blue A - * @description B value of color part A - */ - rgb_blue_a: number; - /** @description Blank image of the second separated color */ - part_b: components["schemas"]["ImageField"]; - /** - * Rgb Red B - * @description R value of color part B - */ - rgb_red_b: number; - /** - * Rgb Green B - * @description G value of color part B - */ - rgb_green_b: number; - /** - * Rgb Blue B - * @description B value of color part B - */ - rgb_blue_b: number; - /** - * type - * @default cmyk_separation_output - * @constant - */ - type: "cmyk_separation_output"; - }; - /** - * CMYK Split - * @description Split an image into subtractive color channels (CMYK+alpha) - */ - CMYKSplitInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to halftone */ - image?: components["schemas"]["ImageField"]; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_split - * @constant - */ - type: "cmyk_split"; - }; - /** - * CMYKSplitOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSplitOutput: { - /** @description Grayscale image of the cyan channel */ - c_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the magenta channel */ - m_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the yellow channel */ - y_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the k channel */ - k_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the alpha channel */ - alpha_channel: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** - * type - * @default cmyk_split_output - * @constant - */ - type: "cmyk_split_output"; - }; - /** - * CV2 Infill - * @description Infills transparent areas of an image using OpenCV Inpainting - */ - CV2InfillInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to infill */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default infill_cv2 - * @constant - */ - type: "infill_cv2"; - }; - /** - * Calculate Image Tiles Even Split - * @description Calculate the coordinates and overlaps of tiles that cover a target image shape. - */ - CalculateImageTilesEvenSplitInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Image Width - * @description The image width, in pixels, to calculate tiles for. - * @default 1024 - */ - image_width?: number; - /** - * Image Height - * @description The image height, in pixels, to calculate tiles for. - * @default 1024 - */ - image_height?: number; - /** - * Num Tiles X - * @description Number of tiles to divide image into on the x axis - * @default 2 - */ - num_tiles_x?: number; - /** - * Num Tiles Y - * @description Number of tiles to divide image into on the y axis - * @default 2 - */ - num_tiles_y?: number; - /** - * Overlap - * @description The overlap, in pixels, between adjacent tiles. 
- * @default 128 - */ - overlap?: number; - /** - * type - * @default calculate_image_tiles_even_split - * @constant - */ - type: "calculate_image_tiles_even_split"; - }; - /** - * Calculate Image Tiles - * @description Calculate the coordinates and overlaps of tiles that cover a target image shape. - */ - CalculateImageTilesInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Image Width - * @description The image width, in pixels, to calculate tiles for. - * @default 1024 - */ - image_width?: number; - /** - * Image Height - * @description The image height, in pixels, to calculate tiles for. - * @default 1024 + * Image Height + * @description The image height, in pixels, to calculate tiles for. + * @default 1024 */ image_height?: number; /** @@ -2095,6 +1708,8 @@ export type components = { * @description Canny edge detection for ControlNet */ CannyImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2482,6 +2097,8 @@ export type components = { * using a mask to only color-correct certain regions of the target image. */ ColorCorrectInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2590,6 +2207,8 @@ export type components = { * @description Generates a color map from the provided image */ ColorMapImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2797,6 +2416,8 @@ export type components = { * @description Applies content shuffle processing to image */ ContentShuffleImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2899,11 +2520,18 @@ export type components = { * @description Model config for ControlNet models (diffusers version). 
*/ ControlNetCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default controlnet @@ -2932,13 +2560,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** * Config * @description path to the checkpoint model config file @@ -2950,11 +2586,18 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default controlnet @@ -2983,13 +2626,23 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; /** * ControlNet @@ -3056,64 +2709,16 @@ export type components = { */ type: "controlnet"; }; - /** ControlNetModelCheckpointConfig */ - ControlNetModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default controlnet - * @constant - */ - model_type: "controlnet"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Config */ - config: string; - }; - /** ControlNetModelDiffusersConfig */ - ControlNetModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default controlnet - * @constant - */ - model_type: "controlnet"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - }; /** * ControlNetModelField * @description ControlNet model field */ ControlNetModelField: { /** - * Model Name - * @description Name of the ControlNet model + * Key + 
* @description Model config record key for the ControlNet model */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * ControlOutput @@ -3442,6 +3047,8 @@ export type components = { * @description Simple inpaint using opencv. */ CvInpaintInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3477,6 +3084,8 @@ export type components = { * @description Generates an openpose pose from an image using DWPose */ DWOpenposeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3677,6 +3286,8 @@ export type components = { * @description Generates a depth map based on the Depth Anything algorithm */ DepthAnythingImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3910,6 +3521,8 @@ export type components = { * @description Upscales an image using RealESRGAN. */ ESRGANInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3996,39 +3609,6 @@ export type components = { */ priority: number; }; - /** - * Equivalent Achromatic Lightness - * @description Calculate Equivalent Achromatic Lightness from image - */ - EquivalentAchromaticLightnessInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image from which to get channel */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default ealightness - * @constant - */ - type: "ealightness"; - }; /** ExposedField */ ExposedField: { /** Nodeid */ @@ -4041,6 +3621,8 @@ export type components = { * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. 
*/ FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4281,39 +3863,6 @@ export type components = { */ y: number; }; - /** - * Flatten Histogram (Grayscale) - * @description Scales the values of an L-mode image by scaling them to the full range 0..255 in equal proportions - */ - FlattenHistogramMono: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Single-channel image for which to flatten the histogram */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default flatten_histogram_mono - * @constant - */ - type: "flatten_histogram_mono"; - }; /** * Float Collection Primitive * @description A collection of float primitive values @@ -4663,7 +4212,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ImageDilateOrErodeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["HandDepthMeshGraphormerProcessor"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["TextToMaskClipsegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["OffsetLatentsInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CMYKColorSeparationInvocation"] | components["schemas"]["NoiseImage2DInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["EquivalentAchromaticLightnessInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageBlendInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageRotateInvocation"] | components["schemas"]["ShadowsHighlightsMidtonesMaskInvocation"] | 
components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["BriaRemoveBackgroundInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageValueThresholdsInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["NoiseSpectralInvocation"] | components["schemas"]["TextMaskInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MaskedBlendLatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CMYKMergeInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["TextToMaskClipsegAdvancedInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageCompositorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["FlattenHistogramMono"] | components["schemas"]["AdjustImageHuePlusInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CMYKSplitInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["BooleanInvocation"] | 
components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageEnhanceInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["LatentConsistencyInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageOffsetInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"]; + [key: string]: components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | 
components["schemas"]["MultiplyInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PairTileImageInvocation"] | 
components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"]; }; /** * Edges @@ -4700,7 +4249,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CMYKSplitOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["CMYKSeparationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["PairTileImageOutput"] | 
components["schemas"]["CLIPOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["HandDepthOutput"] | components["schemas"]["ShadowsHighlightsMidtonesMasksOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ColorCollectionOutput"]; + [key: string]: components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["CLIPOutput"]; }; /** * Errors @@ -4791,88 +4340,13 @@ export type components = { /** Detail */ detail?: components["schemas"]["ValidationError"][]; }; - /** - * Hand Depth w/ MeshGraphormer - * @description Generate hand depth maps to inpaint with using ControlNet - */ - HandDepthMeshGraphormerProcessor: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** - * Resolution - * @description Pixel resolution for output image - * @default 512 - */ - resolution?: number; - /** - * Mask Padding - * @description Amount to pad the hand mask by - * @default 30 - */ - mask_padding?: number; - /** - * Offload - * @description Offload model after usage - * @default false - */ - offload?: boolean; - /** - * type - * @default hand_depth_mesh_graphormer_image_processor - * @constant - */ - type: "hand_depth_mesh_graphormer_image_processor"; - }; - /** - * HandDepthOutput - * @description Base class for to output Meshgraphormer results - */ - HandDepthOutput: { - /** @description Improved hands depth map */ - image: components["schemas"]["ImageField"]; - /** @description Hands area mask */ - mask: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the depth map in pixels - */ - width: number; - /** - * Height - * @description The height of the depth map in pixels - */ - height: number; - /** - * type - * @default meshgraphormer_output - * @constant - */ - type: "meshgraphormer_output"; - }; /** * HED (softedge) Processor * @description Applies HED edge detection to image */ HedImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4969,11 +4443,18 @@ export type components = { * @description Model config for IP Adaptor format models. 
*/ IPAdapterConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default ip_adapter @@ -5001,13 +4482,23 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** Image Encoder Model Id */ + image_encoder_model_id: string; }; /** IPAdapterField */ IPAdapterField: { @@ -5124,34 +4615,10 @@ export type components = { /** IPAdapterModelField */ IPAdapterModelField: { /** - * Model Name - * @description Name of the IP-Adapter model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - }; - /** IPAdapterModelInvokeAIConfig */ - IPAdapterModelInvokeAIConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default ip_adapter - * @constant - */ - model_type: "ip_adapter"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant + * Key + * @description Key to the IP-Adapter model */ - model_format: "invokeai"; - error?: components["schemas"]["ModelError"] | null; + key: string; }; /** IPAdapterOutput */ IPAdapterOutput: { @@ -5238,92 +4705,13 @@ export type components = { */ type: "ideal_size_output"; }; - /** - * Image Layer Blend - * @description Blend two images together, with optional opacity, mask, and blend modes - */ - ImageBlendInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The top image to blend */ - layer_upper?: components["schemas"]["ImageField"]; - /** - * Blend Mode - * @description Available blend modes - * @default Normal - * @enum {string} - */ - blend_mode?: "Normal" | "Lighten Only" | "Darken Only" | "Lighten Only (EAL)" | "Darken Only (EAL)" | "Hue" | "Saturation" | "Color" | "Luminosity" | "Linear Dodge (Add)" | "Subtract" | "Multiply" | "Divide" | "Screen" | "Overlay" | "Linear Burn" | "Difference" | "Hard Light" | "Soft Light" | "Vivid Light" | "Linear Light" | "Color Burn" | "Color Dodge"; - /** - * Opacity - * @description Desired opacity of the upper layer - * @default 1 - */ - opacity?: number; - /** @description Optional mask, used to restrict areas from blending */ - mask?: components["schemas"]["ImageField"] | null; - /** - * Fit To Width - * @description Scale upper layer to fit base width - * @default false - */ - fit_to_width?: boolean; - /** - * Fit To Height - * @description Scale upper layer to fit base height - * @default true - */ - fit_to_height?: boolean; - /** @description The bottom image to blend */ - layer_base?: components["schemas"]["ImageField"]; - /** - * Color Space - * @description Available color spaces for blend computations - * @default Linear RGB - * @enum {string} - */ - color_space?: "RGB" | "Linear RGB" | "HSL (RGB)" | "HSV (RGB)" | "Okhsl" | "Okhsv" | "Oklch (Oklab)" | "LCh (CIELab)"; - /** - * Adaptive Gamut - * @description Adaptive gamut clipping (0=off). Higher prioritizes chroma over lightness - * @default 0 - */ - adaptive_gamut?: number; - /** - * High Precision - * @description Use more steps in computing gamut when possible - * @default true - */ - high_precision?: boolean; - /** - * type - * @default img_blend - * @constant - */ - type: "img_blend"; - }; /** * Blur Image * @description Blurs an image */ ImageBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5382,6 +4770,8 @@ export type components = { * @description Gets a channel from an image. */ ImageChannelInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5422,6 +4812,8 @@ export type components = { * @description Scale a specific color channel of an image. */ ImageChannelMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5473,6 +4865,8 @@ export type components = { * @description Add or subtract a value from a specific color channel of an image. 
*/ ImageChannelOffsetInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5565,10 +4959,12 @@ export type components = { type: "image_collection_output"; }; /** - * Image Compositor - * @description Removes backdrop from subject image then overlays subject on background image + * Convert Image Mode + * @description Converts an image to a different mode. */ - ImageCompositorInvocation: { + ImageConvertInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5588,98 +4984,29 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description Image of the subject on a plain monochrome background */ - image_subject?: components["schemas"]["ImageField"]; - /** @description Image of a background scene */ - image_background?: components["schemas"]["ImageField"]; - /** - * Chroma Key - * @description Can be empty for corner flood select, or CSS-3 color or tuple - * @default - */ - chroma_key?: string; - /** - * Threshold - * @description Subject isolation flood-fill threshold - * @default 50 - */ - threshold?: number; - /** - * Fill X - * @description Scale base subject image to fit background width - * @default false - */ - fill_x?: boolean; - /** - * Fill Y - * @description Scale base subject image to fit background height - * @default true - */ - fill_y?: boolean; - /** - * X Offset - * @description x-offset for the subject - * @default 0 - */ - x_offset?: number; + /** @description The image to convert */ + image?: components["schemas"]["ImageField"]; /** - * Y Offset - * @description y-offset for the subject - * @default 0 + * Mode + * @description The mode to convert to + * @default L + * @enum {string} */ - y_offset?: number; + mode?: "L" | "RGB" | "RGBA" | "CMYK" | "YCbCr" | "LAB" | "HSV" | "I" | "F"; /** * type - * @default img_composite + * @default img_conv * @constant */ - type: "img_composite"; + type: "img_conv"; }; /** - * Convert Image Mode - * @description Converts an image to a different mode. + * Crop Image + * @description Crops an image to a specified box. The box can be outside of the image. */ - ImageConvertInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to convert */ - image?: components["schemas"]["ImageField"]; - /** - * Mode - * @description The mode to convert to - * @default L - * @enum {string} - */ - mode?: "L" | "RGB" | "RGBA" | "CMYK" | "YCbCr" | "LAB" | "HSV" | "I" | "F"; - /** - * type - * @default img_conv - * @constant - */ - type: "img_conv"; - }; - /** - * Crop Image - * @description Crops an image to a specified box. The box can be outside of the image. 
- */ - ImageCropInvocation: { + ImageCropInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5812,127 +5139,6 @@ export type components = { */ board_id?: string | null; }; - /** - * Image Dilate or Erode - * @description Dilate (expand) or erode (contract) an image - */ - ImageDilateOrErodeInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Radius W - * @description Width (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_w?: number; - /** - * Radius H - * @description Height (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_h?: number; - /** - * Mode - * @description How to operate on the image - * @default Dilate - * @enum {string} - */ - mode?: "Dilate" | "Erode"; - /** - * type - * @default img_dilate_erode - * @constant - */ - type: "img_dilate_erode"; - }; - /** - * Enhance Image - * @description Applies processing from PIL's ImageEnhance module. - */ - ImageEnhanceInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image for which to apply processing */ - image?: components["schemas"]["ImageField"]; - /** - * Invert - * @description Whether to invert the image colors - * @default false - */ - invert?: boolean; - /** - * Color - * @description Color enhancement factor - * @default 1 - */ - color?: number; - /** - * Contrast - * @description Contrast enhancement factor - * @default 1 - */ - contrast?: number; - /** - * Brightness - * @description Brightness enhancement factor - * @default 1 - */ - brightness?: number; - /** - * Sharpness - * @description Sharpness enhancement factor - * @default 1 - */ - sharpness?: number; - /** - * type - * @default img_enhance - * @constant - */ - type: "img_enhance"; - }; /** * ImageField * @description An image primitive field @@ -5949,6 +5155,8 @@ export type components = { * @description Adjusts the Hue of an image. 
*/ ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5988,6 +5196,8 @@ export type components = { * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6064,6 +5274,8 @@ export type components = { * @description Linear interpolation of all pixels of an image */ ImageLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6109,6 +5321,8 @@ export type components = { * @description Multiplies two images together using `PIL.ImageChops.multiply()`. */ ImageMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6144,6 +5358,8 @@ export type components = { * @description Add blur to NSFW-flagged images */ ImageNSFWBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6172,57 +5388,6 @@ export type components = { */ type: "img_nsfw"; }; - /** - * Offset Image - * @description Offsets an image by a given percentage (or pixel amount). - */ - ImageOffsetInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * As Pixels - * @description Interpret offsets as pixels rather than percentages - * @default false - */ - as_pixels?: boolean; - /** @description Image to be offset */ - image?: components["schemas"]["ImageField"]; - /** - * X Offset - * @description x-offset for the subject - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_image - * @constant - */ - type: "offset_image"; - }; /** * ImageOutput * @description Base class for nodes that output a single image @@ -6252,6 +5417,8 @@ export type components = { * @description Pastes an image into another image. 
*/ ImagePasteInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6337,6 +5504,8 @@ export type components = { * @description Resizes an image to specific dimensions */ ImageResizeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6384,68 +5553,13 @@ export type components = { */ type: "img_resize"; }; - /** - * Rotate/Flip Image - * @description Rotates an image by a given angle (in degrees clockwise). - */ - ImageRotateInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to be rotated clockwise */ - image?: components["schemas"]["ImageField"]; - /** - * Degrees - * @description Angle (in degrees clockwise) by which to rotate - * @default 90 - */ - degrees?: number; - /** - * Expand To Fit - * @description If true, extends the image boundary to fit the rotated content - * @default true - */ - expand_to_fit?: boolean; - /** - * Flip Horizontal - * @description If true, flips the image horizontally - * @default false - */ - flip_horizontal?: boolean; - /** - * Flip Vertical - * @description If true, flips the image vertically - * @default false - */ - flip_vertical?: boolean; - /** - * type - * @default rotate_image - * @constant - */ - type: "rotate_image"; - }; /** * Scale Image * @description Scales an image by a factor */ ImageScaleInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6554,10 +5668,12 @@ export type components = { thumbnail_url: string; }; /** - * Image Value Thresholds - * @description Clip image to pure black/white past specified thresholds + * Add Invisible Watermark + * @description Add an invisible watermark to an image */ - ImageValueThresholdsInvocation: { + ImageWatermarkInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6577,105 +5693,44 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description The image from which to create a mask */ + /** @description The image to check */ image?: components["schemas"]["ImageField"]; /** - * Invert Output - * @description Make light areas dark and vice versa - * @default false - */ - invert_output?: boolean; - /** - * Renormalize Values - * @description Rescale remaining values from minimum to maximum - * @default false - */ - renormalize_values?: boolean; - /** - * Lightness Only - * @description If 
true, only applies to image lightness (CIELa*b*) - * @default false + * Text + * @description Watermark text + * @default InvokeAI */ - lightness_only?: boolean; + text?: string; /** - * Threshold Upper - * @description Threshold above which will be set to full value - * @default 0.5 + * type + * @default img_watermark + * @constant */ - threshold_upper?: number; + type: "img_watermark"; + }; + /** ImagesDownloaded */ + ImagesDownloaded: { /** - * Threshold Lower - * @description Threshold below which will be set to minimum value - * @default 0.5 + * Response + * @description If defined, the message to display to the user when images begin downloading */ - threshold_lower?: number; + response: string | null; + }; + /** ImagesUpdatedFromListResult */ + ImagesUpdatedFromListResult: { /** - * type - * @default img_val_thresholds - * @constant + * Updated Image Names + * @description The image names that were updated */ - type: "img_val_thresholds"; + updated_image_names: string[]; }; /** - * Add Invisible Watermark - * @description Add an invisible watermark to an image - */ - ImageWatermarkInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to check */ - image?: components["schemas"]["ImageField"]; - /** - * Text - * @description Watermark text - * @default InvokeAI - */ - text?: string; - /** - * type - * @default img_watermark - * @constant - */ - type: "img_watermark"; - }; - /** ImagesDownloaded */ - ImagesDownloaded: { - /** - * Response - * @description If defined, the message to display to the user when images begin downloading - */ - response: string | null; - }; - /** ImagesUpdatedFromListResult */ - ImagesUpdatedFromListResult: { - /** - * Updated Image Names - * @description The image names that were updated - */ - updated_image_names: string[]; - }; - /** - * Solid Color Infill - * @description Infills transparent areas of an image with a solid color + * Solid Color Infill + * @description Infills transparent areas of an image with a solid color */ InfillColorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6719,6 +5774,8 @@ export type components = { * @description Infills transparent areas of an image using the PatchMatch algorithm */ InfillPatchMatchInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6765,6 +5822,8 @@ export type components = { * @description Infills transparent areas of an image with tiles of the image */ InfillTileInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | 
null; /** @@ -7065,6 +6124,8 @@ export type components = { * @description Infills transparent areas of an image using the LaMa model */ LaMaInfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7093,96 +6154,6 @@ export type components = { */ type: "infill_lama"; }; - /** - * Latent Consistency MonoNode - * @description Wrapper node around diffusers LatentConsistencyTxt2ImgPipeline - */ - LatentConsistencyInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description The prompt to use - */ - prompt?: string; - /** - * Num Inference Steps - * @description The number of inference steps to use, 4-8 recommended - * @default 8 - */ - num_inference_steps?: number; - /** - * Guidance Scale - * @description The guidance scale to use - * @default 8 - */ - guidance_scale?: number; - /** - * Batches - * @description The number of batches to use - * @default 1 - */ - batches?: number; - /** - * Images Per Batch - * @description The number of images per batch to use - * @default 1 - */ - images_per_batch?: number; - /** - * Seeds - * @description List of noise seeds to use - */ - seeds?: number[]; - /** - * Lcm Origin Steps - * @description The lcm origin steps to use - * @default 50 - */ - lcm_origin_steps?: number; - /** - * Width - * @description The width to use - * @default 512 - */ - width?: number; - /** - * Height - * @description The height to use - * @default 512 - */ - height?: number; - /** - * Precision - * @description floating point precision - * @default fp16 - * @enum {string} - */ - precision?: "fp16" | "fp32"; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; - /** - * type - * @default latent_consistency_mononode - * @constant - */ - type: "latent_consistency_mononode"; - }; /** * Latents Collection Primitive * @description A collection of latents tensor primitive values @@ -7310,6 +6281,8 @@ export type components = { * @description Generates an image from latents. */ LatentsToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7357,6 +6330,8 @@ export type components = { * @description Applies leres processing to image */ LeresImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7441,46 +6416,13 @@ export type components = { /** @description Type of commercial use allowed or 'No' if no commercial use is allowed. */ AllowCommercialUse?: components["schemas"]["CommercialUsage"]; }; - /** - * Linear UI Image Output - * @description Handles Linear UI Image Outputting tasks. 
- */ - LinearUIOutputInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default false - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** - * type - * @default linear_ui_output - * @constant - */ - type: "linear_ui_output"; - }; /** * Lineart Anime Processor * @description Applies line art anime processing to image */ LineartAnimeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7526,6 +6468,8 @@ export type components = { * @description Applies line art processing to image */ LineartImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7577,11 +6521,18 @@ export type components = { * @description Model config for LoRA/Lycoris models. */ LoRAConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default lora @@ -7609,13 +6560,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** * LoRAMetadataField @@ -7630,42 +6589,17 @@ export type components = { */ weight: number; }; - /** LoRAModelConfig */ - LoRAModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default lora - * @constant - */ - model_type: "lora"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - model_format: components["schemas"]["LoRAModelFormat"]; - error?: components["schemas"]["ModelError"] | null; - }; /** * LoRAModelField * @description LoRA model field */ LoRAModelField: { /** - * Model Name - * @description Name of the LoRA model + * Key + * @description LoRA model key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + 
key: string; }; - /** - * LoRAModelFormat - * @enum {string} - */ - LoRAModelFormat: "lycoris" | "diffusers"; /** * LocalModelSource * @description A local file or directory path. @@ -7693,16 +6627,12 @@ export type components = { /** LoraInfo */ LoraInfo: { /** - * Model Name - * @description Info to load submodel + * Key + * @description Key of model as returned by ModelRecordServiceBase.get_model() */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Info to load submodel */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; /** @description Info to load submodel */ - submodel?: components["schemas"]["SubModelType"] | null; + submodel_type?: components["schemas"]["SubModelType"] | null; /** * Weight * @description Lora's weight which to use when apply to model @@ -7786,11 +6716,18 @@ export type components = { * @description Model config for main checkpoint models. */ MainCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default main @@ -7819,17 +6756,32 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; + variant?: components["schemas"]["ModelVariantType"]; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; + /** + * Upcast Attention + * @default false + */ + upcast_attention?: boolean; /** * Ztsnr Training * @default false @@ -7846,11 +6798,18 @@ export type components = { * @description Model config for main diffusers models. 
*/ MainDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default main @@ -7879,29 +6838,39 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; - /** - * Ztsnr Training - * @default false - */ - ztsnr_training?: boolean; + variant?: components["schemas"]["ModelVariantType"]; /** @default epsilon */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default false */ upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false + */ + ztsnr_training?: boolean; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; /** * MainModelField @@ -7909,14 +6878,10 @@ export type components = { */ MainModelField: { /** - * Model Name - * @description Name of the model + * Key + * @description Model key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; }; /** * Main Model @@ -7954,6 +6919,8 @@ export type components = { * @description Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`. */ MaskCombineInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7989,6 +6956,8 @@ export type components = { * @description Applies an edge mask to an image */ MaskEdgeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8042,6 +7011,8 @@ export type components = { * @description Extracts the alpha channel of an image as a mask. */ MaskFromAlphaInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8077,51 +7048,12 @@ export type components = { type: "tomask"; }; /** - * Blend Latents/Noise (Masked) - * @description Blend two latents using a given alpha and mask. Latents must have same size. 
- */ - MaskedBlendLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents_a?: components["schemas"]["LatentsField"]; - /** @description Latents tensor */ - latents_b?: components["schemas"]["LatentsField"]; - /** @description Mask for blending in latents B */ - mask?: components["schemas"]["ImageField"]; - /** - * Alpha - * @description Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B. - * @default 0.5 - */ - alpha?: number; - /** - * type - * @default lmblend - * @constant - */ - type: "lmblend"; - }; - /** - * Mediapipe Face Processor - * @description Applies mediapipe face processing to image + * Mediapipe Face Processor + * @description Applies mediapipe face processing to image */ MediapipeFaceProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8201,43 +7133,13 @@ export type components = { */ type: "merge_metadata"; }; - /** MergeModelsBody */ - MergeModelsBody: { - /** - * Model Names - * @description model name - */ - model_names: string[]; - /** - * Merged Model Name - * @description Name of destination model - */ - merged_model_name: string | null; - /** - * Alpha - * @description Alpha weighting strength to apply to 2d and 3d models - * @default 0.5 - */ - alpha?: number | null; - /** @description Interpolation method */ - interp: components["schemas"]["MergeInterpolationMethod"] | null; - /** - * Force - * @description Force merging of models created with different versions of diffusers - * @default false - */ - force?: boolean | null; - /** - * Merge Dest Directory - * @description Save the merged model to the designated directory (with 'merged_model_name' appended) - */ - merge_dest_directory?: string | null; - }; /** * Merge Tiles to Image * @description Merge multiple tile images into a single image. 
*/ MergeTilesToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8404,6 +7306,8 @@ export type components = { * @description Applies Midas depth processing to image */ MidasDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8449,6 +7353,8 @@ export type components = { * @description Applies MLSD processing to image */ MlsdImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8501,11 +7407,6 @@ export type components = { */ type: "mlsd_image_processor"; }; - /** - * ModelError - * @constant - */ - ModelError: "not_found"; /** * ModelFormat * @description Storage format of model. @@ -8515,16 +7416,12 @@ export type components = { /** ModelInfo */ ModelInfo: { /** - * Model Name - * @description Info to load submodel + * Key + * @description Key of model as returned by ModelRecordServiceBase.get_model() */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Info to load submodel */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; /** @description Info to load submodel */ - submodel?: components["schemas"]["SubModelType"] | null; + submodel_type?: components["schemas"]["SubModelType"] | null; }; /** * ModelInstallJob @@ -8550,7 +7447,7 @@ export type components = { * Config Out * @description After successful installation, this will hold the configuration object. 
*/ - config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"] | null; + config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"] | null; /** * Inplace * @description Leave model in its current location; otherwise install under models directory @@ -8629,7 +7526,7 @@ export type components = { * @description Various hugging face variants on the diffusers format. * @enum {string} */ - ModelRepoVariant: "default" | "fp16" | "fp32" | "onnx" | "openvino" | "flax"; + ModelRepoVariant: "" | "fp16" | "fp32" | "onnx" | "openvino" | "flax"; /** * ModelSummary * @description A short summary of models for UI listing purposes. @@ -8641,9 +7538,9 @@ export type components = { */ key: string; /** @description model type */ - type: components["schemas"]["invokeai__backend__model_manager__config__ModelType"]; + type: components["schemas"]["ModelType"]; /** @description base model */ - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + base: components["schemas"]["BaseModelType"]; /** @description model format */ format: components["schemas"]["ModelFormat"]; /** @@ -8662,6 +7559,26 @@ export type components = { */ tags: string[]; }; + /** + * ModelType + * @description Model type. + * @enum {string} + */ + ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; + /** + * ModelVariantType + * @description Variant type. + * @enum {string} + */ + ModelVariantType: "normal" | "inpaint" | "depth"; + /** + * ModelsList + * @description Return list of configs. 
+ */ + ModelsList: { + /** Models */ + models: ((components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"])[]; + }; /** * Multiply Integers * @description Multiplies two numbers @@ -8721,86 +7638,6 @@ export type components = { */ value: string | number; }; - /** - * 2D Noise Image - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseImage2DInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noiseimg_2d - * @constant - */ - type: "noiseimg_2d"; - }; /** * Noise * @description Generates latent noise. @@ -8878,89 +7715,13 @@ export type components = { */ type: "noise_output"; }; - /** - * Noise (Spectral characteristics) - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseSpectralInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noise_spectral - * @constant - */ - type: "noise_spectral"; - }; /** * Normal BAE Processor * @description Applies NormalBae processing to image */ NormalbaeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -9002,121 +7763,106 @@ export type components = { type: "normalbae_image_processor"; }; /** - * ONNX Latents to Image - * @description Generates an image from latents. + * ONNXSD1Config + * @description Model config for ONNX format models based on sd-1. */ - ONNXLatentsToImageInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; + ONNXSD1Config: { /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + * Path + * @description filesystem path to the model file or directory */ - id: string; + path: string; /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false + * Name + * @description model name */ - is_intermediate?: boolean; + name: string; /** - * Use Cache - * @description Whether or not to use the cache - * @default true + * Base + * @default sd-1 + * @constant */ - use_cache?: boolean; - /** @description Denoised latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** @description VAE */ - vae?: components["schemas"]["VaeField"]; + base?: "sd-1"; /** - * type - * @default l2i_onnx + * Type + * @default onnx * @constant */ - type: "l2i_onnx"; - }; - /** - * ONNXModelLoaderOutput - * @description Model loader output - */ - ONNXModelLoaderOutput: { + type?: "onnx"; /** - * UNet - * @description UNet (scheduler, LoRAs) + * Format + * @enum {string} */ - unet?: components["schemas"]["UNetField"]; + format: "onnx" | "olive"; /** - * CLIP - * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * Key + * @description unique key for model + * @default */ - clip?: components["schemas"]["ClipField"]; + key?: string; /** - * VAE Decoder - * @description VAE + * Original Hash + * @description original fasthash of model contents */ - vae_decoder?: components["schemas"]["VaeField"]; + original_hash?: string | null; /** - * VAE Encoder - * @description VAE + * Current Hash + * @description current fasthash of model contents */ - vae_encoder?: components["schemas"]["VaeField"]; + current_hash?: string | null; /** - * type - * @default model_loader_output_onnx - * @constant + * Description + * @description human readable description of the model */ - type: "model_loader_output_onnx"; - }; - /** ONNX Prompt (Raw) */ - ONNXPromptInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; + description?: string | null; /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false + * Source + * @description model original source (path, URL or repo_id) */ - is_intermediate?: boolean; + source?: string | null; /** - * Use Cache - * @description Whether or not to use the cache - * @default true + * Last Modified + * @description timestamp for modification time */ - use_cache?: boolean; + last_modified?: number | null; + /** Vae */ + vae?: string | null; + /** @default normal */ + variant?: components["schemas"]["ModelVariantType"]; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** - * Prompt - * @description Raw prompt text (no parsing) - * @default + * Upcast Attention + * @default false */ - prompt?: string; - /** @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ - clip?: components["schemas"]["ClipField"]; + upcast_attention?: boolean; /** - * type - * @default prompt_onnx - * @constant + * Ztsnr Training + * @default false */ - type: "prompt_onnx"; + ztsnr_training?: boolean; }; /** - * ONNXSD1Config - * @description Model config for ONNX format models based on sd-1. + * ONNXSD2Config + * @description Model config for ONNX format models based on sd-2. 
*/ - ONNXSD1Config: { - /** Path */ + ONNXSD2Config: { + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; /** * Base - * @default sd-1 + * @default sd-2 * @constant */ - base?: "sd-1"; + base?: "sd-2"; /** * Type * @default onnx @@ -9144,45 +7890,59 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; + variant?: components["schemas"]["ModelVariantType"]; + /** @default v_prediction */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** - * Ztsnr Training - * @default false + * Upcast Attention + * @default true */ - ztsnr_training?: boolean; - /** @default epsilon */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + upcast_attention?: boolean; /** - * Upcast Attention + * Ztsnr Training * @default false */ - upcast_attention?: boolean; + ztsnr_training?: boolean; }; /** - * ONNXSD2Config - * @description Model config for ONNX format models based on sd-2. + * ONNXSDXLConfig + * @description Model config for ONNX format models based on sdxl. 
*/ - ONNXSD2Config: { - /** Path */ + ONNXSDXLConfig: { + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; /** * Base - * @default sd-2 + * @default sdxl * @constant */ - base?: "sd-2"; + base?: "sdxl"; /** * Type * @default onnx @@ -9210,189 +7970,37 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; - /** - * Ztsnr Training - * @default false - */ - ztsnr_training?: boolean; + variant?: components["schemas"]["ModelVariantType"]; /** @default v_prediction */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention - * @default true - */ - upcast_attention?: boolean; - }; - /** ONNXStableDiffusion1ModelConfig */ - ONNXStableDiffusion1ModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default onnx - * @constant - */ - model_type: "onnx"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "onnx"; - error?: components["schemas"]["ModelError"] | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** ONNXStableDiffusion2ModelConfig */ - ONNXStableDiffusion2ModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default onnx - * @constant - */ - model_type: "onnx"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "onnx"; - error?: components["schemas"]["ModelError"] | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - prediction_type: components["schemas"]["invokeai__backend__model_management__models__base__SchedulerPredictionType"]; - /** Upcast Attention */ - upcast_attention: boolean; - }; - /** - * ONNX Text to Latents - * @description Generates latents from conditionings. - */ - ONNXTextToLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
* @default false */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Positive conditioning tensor */ - positive_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Negative conditioning tensor */ - negative_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Noise tensor */ - noise?: components["schemas"]["LatentsField"]; - /** - * Steps - * @description Number of steps to run - * @default 10 - */ - steps?: number; - /** - * Cfg Scale - * @description Classifier-Free Guidance scale - * @default 7.5 - */ - cfg_scale?: number | number[]; - /** - * Scheduler - * @description Scheduler to use during inference - * @default euler - * @enum {string} - */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm"; - /** - * Precision - * @description Precision to use - * @default tensor(float16) - * @enum {string} - */ - precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)"; - /** @description UNet (scheduler, LoRAs) */ - unet?: components["schemas"]["UNetField"]; - /** - * Control - * @description ControlNet(s) to apply - */ - control?: components["schemas"]["ControlField"] | components["schemas"]["ControlField"][]; - /** - * type - * @default t2l_onnx - * @constant - */ - type: "t2l_onnx"; - }; - /** - * Offset Latents - * @description Offsets a latents tensor by a given percentage of height/width. - */ - OffsetLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; + upcast_attention?: boolean; /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. + * Ztsnr Training * @default false */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** - * X Offset - * @description Approx percentage to offset (H) - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description Approx percentage to offset (V) - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_latents - * @constant - */ - type: "offset_latents"; + ztsnr_training?: boolean; }; /** OffsetPaginatedResults[BoardDTO] */ OffsetPaginatedResults_BoardDTO_: { @@ -9440,52 +8048,6 @@ export type components = { */ items: components["schemas"]["ImageDTO"][]; }; - /** - * OnnxModelField - * @description Onnx model field - */ - OnnxModelField: { - /** - * Model Name - * @description Name of the model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - }; - /** - * ONNX Main Model - * @description Loads a main model, outputting its submodels. 
- */ - OnnxModelLoaderInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description ONNX Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["OnnxModelField"]; - /** - * type - * @default onnx_model_loader - * @constant - */ - type: "onnx_model_loader"; - }; /** PaginatedResults[ModelSummary] */ PaginatedResults_ModelSummary_: { /** @@ -9591,6 +8153,8 @@ export type components = { * @description Applies PIDI processing to image */ PidiImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10443,6 +9007,8 @@ export type components = { * @description Saves an image. Unlike an image primitive, this invocation stores a copy of the image. */ SaveImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10464,8 +9030,6 @@ export type components = { use_cache?: boolean; /** @description The image to process */ image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; /** * type * @default save_image @@ -10573,6 +9137,12 @@ export type components = { */ type: "scheduler_output"; }; + /** + * SchedulerPredictionType + * @description Scheduler prediction type. + * @enum {string} + */ + SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; /** * Seamless * @description Applies the seamless transformation to the Model UNet and VAE. @@ -10651,6 +9221,8 @@ export type components = { * @description Applies segment anything processing to image */ SegmentAnythingProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10890,108 +9462,8 @@ export type components = { total: number; }; /** - * Shadows/Highlights/Midtones - * @description Extract a Shadows/Highlights/Midtones mask from an image - */ - ShadowsHighlightsMidtonesMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image from which to extract mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Highlight Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.75 - */ - highlight_threshold?: number; - /** - * Upper Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.7 - */ - upper_mid_threshold?: number; - /** - * Lower Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.3 - */ - lower_mid_threshold?: number; - /** - * Shadow Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.25 - */ - shadow_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels to grow (or shrink) the mask areas - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Gaussian blur radius to apply to the masks - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default shmmask - * @constant - */ - type: "shmmask"; - }; - /** ShadowsHighlightsMidtonesMasksOutput */ - ShadowsHighlightsMidtonesMasksOutput: { - /** @description Soft-edged highlights mask */ - highlights_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged midtones mask */ - midtones_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged shadows mask */ - shadows_mask?: components["schemas"]["ImageField"]; - /** - * Width - * @description Width of the input/outputs - */ - width: number; - /** - * Height - * @description Height of the input/outputs - */ - height: number; - /** - * type - * @default shmmask_output - * @constant - */ - type: "shmmask_output"; - }; - /** - * Show Image - * @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline. + * Show Image + * @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline. 
*/ ShowImageInvocation: { /** @@ -11020,162 +9492,6 @@ export type components = { */ type: "show_image"; }; - /** StableDiffusion1ModelCheckpointConfig */ - StableDiffusion1ModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion1ModelDiffusersConfig */ - StableDiffusion1ModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion2ModelCheckpointConfig */ - StableDiffusion2ModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion2ModelDiffusersConfig */ - StableDiffusion2ModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusionXLModelCheckpointConfig */ - StableDiffusionXLModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: 
components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusionXLModelDiffusersConfig */ - StableDiffusionXLModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; /** * Step Param Easing * @description Experimental per-step parameter easing for denoising steps @@ -11631,6 +9947,7 @@ export type components = { }; /** * SubModelType + * @description Submodel type. * @enum {string} */ SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; @@ -11768,37 +10085,13 @@ export type components = { */ type: "t2i_adapter"; }; - /** T2IAdapterModelDiffusersConfig */ - T2IAdapterModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default t2i_adapter - * @constant - */ - model_type: "t2i_adapter"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - }; /** T2IAdapterModelField */ T2IAdapterModelField: { /** - * Model Name - * @description Name of the T2I-Adapter model + * Key + * @description Model record key for the T2I-Adapter model */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** T2IAdapterOutput */ T2IAdapterOutput: { @@ -11819,11 +10112,18 @@ export type components = { * @description Model config for T2I. */ T2IConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default t2i_adapter @@ -11851,13 +10151,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** TBLR */ TBLR: { @@ -11871,95 +10179,79 @@ export type components = { right: number; }; /** - * Text Mask - * @description Creates a 2D rendering of a text mask from a given font + * TextualInversionConfig + * @description Model config for textual inversion embeddings. 
*/ - TextMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; + TextualInversionConfig: { /** - * Width - * @description The width of the desired mask - * @default 512 + * Path + * @description filesystem path to the model file or directory */ - width?: number; + path: string; /** - * Height - * @description The height of the desired mask - * @default 512 + * Name + * @description model name */ - height?: number; + name: string; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** - * Text - * @description The text to render - * @default + * Type + * @default embedding + * @constant */ - text?: string; + type?: "embedding"; /** - * Font - * @description Path to a FreeType-supported TTF/OTF font file - * @default + * Format + * @enum {string} */ - font?: string; + format: "embedding_file" | "embedding_folder"; /** - * Size - * @description Desired point size of text to use - * @default 64 + * Key + * @description unique key for model + * @default */ - size?: number; + key?: string; /** - * Angle - * @description Angle of rotation to apply to the text - * @default 0 + * Original Hash + * @description original fasthash of model contents */ - angle?: number; + original_hash?: string | null; /** - * X Offset - * @description x-offset for text rendering - * @default 24 + * Current Hash + * @description current fasthash of model contents */ - x_offset?: number; + current_hash?: string | null; /** - * Y Offset - * @description y-offset for text rendering - * @default 36 + * Description + * @description human readable description of the model */ - y_offset?: number; + description?: string | null; /** - * Invert - * @description Whether to invert color of the output - * @default false + * Source + * @description model original source (path, URL or repo_id) */ - invert?: boolean; + source?: string | null; /** - * type - * @default text_mask - * @constant + * Last Modified + * @description timestamp for modification time */ - type: "text_mask"; + last_modified?: number | null; + }; + /** Tile */ + Tile: { + /** @description The coordinates of this tile relative to its parent image. */ + coords: components["schemas"]["TBLR"]; + /** @description The amount of overlap with adjacent tiles on each side of this tile. 
*/ + overlap: components["schemas"]["TBLR"]; }; /** - * Text to Mask Advanced (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt + * Tile Resample Processor + * @description Tile resampler processor */ - TextToMaskClipsegAdvancedInvocation: { + TileResamplerProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -11979,236 +10271,7 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt 1 - * @description First prompt with which to create a mask - */ - prompt_1?: string; - /** - * Prompt 2 - * @description Second prompt with which to create a mask (optional) - */ - prompt_2?: string; - /** - * Prompt 3 - * @description Third prompt with which to create a mask (optional) - */ - prompt_3?: string; - /** - * Prompt 4 - * @description Fourth prompt with which to create a mask (optional) - */ - prompt_4?: string; - /** - * Combine - * @description How to combine the results - * @default or - * @enum {string} - */ - combine?: "or" | "and" | "none (rgba multiplex)"; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 1 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0 - */ - background_threshold?: number; - /** - * type - * @default txt2mask_clipseg_adv - * @constant - */ - type: "txt2mask_clipseg_adv"; - }; - /** - * Text to Mask (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt - * @description The prompt with which to create a mask - */ - prompt?: string; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 0.4 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0.4 - */ - background_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels by which to grow (or shrink) mask after thresholding - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Radius of blur to apply after thresholding - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default txt2mask_clipseg - * @constant - */ - type: "txt2mask_clipseg"; - }; - /** - * TextualInversionConfig - * @description Model config for textual inversion embeddings. - */ - TextualInversionConfig: { - /** Path */ - path: string; - /** Name */ - name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; - /** - * Type - * @default embedding - * @constant - */ - type?: "embedding"; - /** - * Format - * @enum {string} - */ - format: "embedding_file" | "embedding_folder"; - /** - * Key - * @description unique key for model - * @default - */ - key?: string; - /** - * Original Hash - * @description original fasthash of model contents - */ - original_hash?: string | null; - /** - * Current Hash - * @description current fasthash of model contents - */ - current_hash?: string | null; - /** Description */ - description?: string | null; - /** - * Source - * @description Model download source (URL or repo_id) - */ - source?: string | null; - }; - /** TextualInversionModelConfig */ - TextualInversionModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default embedding - * @constant - */ - model_type: "embedding"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** Model Format */ - model_format: null; - error?: components["schemas"]["ModelError"] | null; - }; - /** Tile */ - Tile: { - /** @description The coordinates of this tile relative to its parent image. */ - coords: components["schemas"]["TBLR"]; - /** @description The amount of overlap with adjacent tiles on each side of this tile. */ - overlap: components["schemas"]["TBLR"]; - }; - /** - * Tile Resample Processor - * @description Tile resampler processor - */ - TileResamplerProcessorInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to process */ + /** @description The image to process */ image?: components["schemas"]["ImageField"]; /** * Down Sampling Rate @@ -12339,7 +10402,7 @@ export type components = { }; /** * UNetOutput - * @description Base class for invocations that output a UNet field + * @description Base class for invocations that output a UNet field. */ UNetOutput: { /** @@ -12378,6 +10441,8 @@ export type components = { * @description Applies an unsharp mask filter to an image */ UnsharpMaskInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12437,12 +10502,10 @@ export type components = { */ VAEModelField: { /** - * Model Name - * @description Name of the model + * Key + * @description Model's key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * VAEOutput @@ -12466,11 +10529,18 @@ export type components = { * @description Model config for standalone VAE models. */ VaeCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default vae @@ -12499,24 +10569,39 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** * VaeDiffusersConfig * @description Model config for standalone VAE models (diffusers version). 
*/ VaeDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default vae @@ -12545,13 +10630,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** VaeField */ VaeField: { @@ -12597,29 +10690,6 @@ export type components = { */ type: "vae_loader"; }; - /** VaeModelConfig */ - VaeModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default vae - * @constant - */ - model_type: "vae"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - model_format: components["schemas"]["VaeModelFormat"]; - error?: components["schemas"]["ModelError"] | null; - }; - /** - * VaeModelFormat - * @enum {string} - */ - VaeModelFormat: "checkpoint" | "diffusers"; /** ValidationError */ ValidationError: { /** Location */ @@ -12846,6 +10916,8 @@ export type components = { * @description Applies Zoe depth processing to image */ ZoeDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12874,63 +10946,6 @@ export type components = { */ type: "zoe_depth_image_processor"; }; - /** - * ModelsList - * @description Return list of configs. 
- */ - invokeai__app__api__routers__model_records__ModelsList: { - /** Models */ - models: ((components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"])[]; - }; - /** ModelsList */ - invokeai__app__api__routers__models__ModelsList: { - /** Models */ - models: (components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[]; - }; - /** - * BaseModelType - * @enum {string} - */ - invokeai__backend__model_management__models__base__BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; - /** - * ModelType - * @enum {string} - */ - invokeai__backend__model_management__models__base__ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; - /** - * ModelVariantType - * @enum {string} - */ - invokeai__backend__model_management__models__base__ModelVariantType: "normal" | "inpaint" | "depth"; - /** - * SchedulerPredictionType - * @enum {string} - */ - invokeai__backend__model_management__models__base__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; - /** - * BaseModelType - * @description Base model type. - * @enum {string} - */ - invokeai__backend__model_manager__config__BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; - /** - * ModelType - * @description Model type. - * @enum {string} - */ - invokeai__backend__model_manager__config__ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; - /** - * ModelVariantType - * @description Variant type. - * @enum {string} - */ - invokeai__backend__model_manager__config__ModelVariantType: "normal" | "inpaint" | "depth"; - /** - * SchedulerPredictionType - * @description Scheduler prediction type. - * @enum {string} - */ - invokeai__backend__model_manager__config__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; /** * Classification * @description The classification of an Invocation. 
@@ -13098,53 +11113,65 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * T2IAdapterModelFormat + * VaeModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + VaeModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - IPAdapterModelFormat: "invokeai"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusionXLModelFormat + * CLIPVisionModelFormat + * @description An enumeration. + * @enum {string} + */ + CLIPVisionModelFormat: "diffusers"; + /** + * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusionOnnxModelFormat + * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * CLIPVisionModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - CLIPVisionModelFormat: "diffusers"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusion1ModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + T2IAdapterModelFormat: "diffusers"; /** - * ControlNetModelFormat + * LoRAModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + LoRAModelFormat: "lycoris" | "diffusers"; + /** + * IPAdapterModelFormat + * @description An enumeration. 
+ * @enum {string} + */ + IPAdapterModelFormat: "invokeai"; }; responses: never; parameters: never; @@ -13214,328 +11241,6 @@ export type operations = { }; }; }; - /** - * List Models - * @description Gets a list of models - */ - list_models: { - parameters: { - query?: { - /** @description Base models to include */ - base_models?: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"][] | null; - /** @description The type of model to get */ - model_type?: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"] | null; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["invokeai__app__api__routers__models__ModelsList"]; - }; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Delete Model - * @description Delete Model - */ - del_model: { - parameters: { - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; - }; - }; - responses: { - /** @description Model deleted successfully */ - 204: { - content: never; - }; - /** @description Model not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Update Model - * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. 
- */ - update_model: { - parameters: { - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; - }; - }; - requestBody: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - responses: { - /** @description The model was updated successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to the new name */ - 409: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Import Model - * @description Add a model using its local path, repo_id, or remote URL. 
Model characteristics will be probed and configured automatically - */ - import_model: { - requestBody: { - content: { - "application/json": components["schemas"]["Body_import_model"]; - }; - }; - responses: { - /** @description The model imported successfully */ - 201: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to this path or repo_id */ - 409: { - content: never; - }; - /** @description Unrecognized file/folder format */ - 415: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - /** @description The model appeared to import successfully, but could not be found in the model manager */ - 424: { - content: never; - }; - }; - }; - /** - * Add Model - * @description Add a model using the configuration information appropriate for its type. 
Only local models can be added by path - */ - add_model: { - requestBody: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - responses: { - /** @description The model added successfully */ - 201: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to this path or repo_id */ - 409: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - /** @description The model appeared to add successfully, but could not be found in the model manager */ - 424: { - content: never; - }; - }; - }; - /** - * Convert Model - * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. 
- */ - convert_model: { - parameters: { - query?: { - /** @description Save the converted model to the designated directory */ - convert_dest_directory?: string | null; - }; - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; - }; - }; - responses: { - /** @description Model converted successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description Model not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** Search For Models */ - search_for_models: { - parameters: { - query: { - /** @description Directory path to search for models */ - search_path: string; - }; - }; - responses: { - /** @description Directory searched successfully */ - 200: { - content: { - "application/json": string[]; - }; - }; - /** @description Invalid directory path */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Ckpt Configs - * @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT. - */ - list_ckpt_configs: { - responses: { - /** @description paths retrieved successfully */ - 200: { - content: { - "application/json": string[]; - }; - }; - }; - }; - /** - * Sync To Config - * @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize - * in-memory data structures with disk data structures. 
- */ - sync_to_config: { - responses: { - /** @description synchronization successful */ - 201: { - content: { - "application/json": boolean; - }; - }; - }; - }; - /** - * Merge Models - * @description Convert a checkpoint model into a diffusers model - */ - merge_models: { - parameters: { - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - }; - }; - requestBody: { - content: { - "application/json": components["schemas"]["Body_merge_models"]; - }; - }; - responses: { - /** @description Model converted successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Incompatible models */ - 400: { - content: never; - }; - /** @description One or more models not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; /** * List Model Records * @description Get a list of models. @@ -13544,9 +11249,9 @@ export type operations = { parameters: { query?: { /** @description Base models to include */ - base_models?: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"][] | null; + base_models?: components["schemas"]["BaseModelType"][] | null; /** @description The type of model to get */ - model_type?: components["schemas"]["invokeai__backend__model_manager__config__ModelType"] | null; + model_type?: components["schemas"]["ModelType"] | null; /** @description Exact match on the name of the model */ model_name?: string | null; /** @description Exact match on the format of the model (e.g. 
'diffusers') */ @@ -13557,7 +11262,7 @@ export type operations = { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["invokeai__app__api__routers__model_records__ModelsList"]; + "application/json": components["schemas"]["ModelsList"]; }; }; /** @description Validation Error */ @@ -13580,10 +11285,10 @@ export type operations = { }; }; responses: { - /** @description Success */ + /** @description The model configuration was retrieved successfully */ 200: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; /** @description Bad request */ @@ -13646,14 +11351,26 @@ export type operations = { }; requestBody: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + /** + * @example { + * "path": "/path/to/model", + * "name": "model_name", + * "base": "sd-1", + * "type": "main", + * "format": "checkpoint", + * "config": "configs/stable-diffusion/v1-inference.yaml", + * "description": "Model description", + * "variant": "normal" + * } + */ + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; responses: { /** @description The 
model was updated successfully */ 200: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; /** @description Bad request */ @@ -13718,7 +11435,7 @@ export type operations = { }; }; responses: { - /** @description Success */ + /** @description The model metadata was retrieved successfully */ 200: { content: { "application/json": (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null; @@ -13769,7 +11486,7 @@ export type operations = { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["invokeai__app__api__routers__model_records__ModelsList"]; + "application/json": components["schemas"]["ModelsList"]; }; }; /** @description Validation Error */ @@ -13787,14 +11504,26 @@ export type operations = { add_model_record: { requestBody: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + /** + * @example { + * "path": "/path/to/model", + * "name": "model_name", + * "base": "sd-1", + * "type": "main", + * "format": "checkpoint", + * "config": "configs/stable-diffusion/v1-inference.yaml", + * "description": "Model description", + * "variant": "normal" + * } + */ + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | 
components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; responses: { /** @description The model added successfully */ 201: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; /** @description There is already a model corresponding to this path or repo_id */ @@ -13814,22 +11543,81 @@ export type operations = { }; }; /** - * List Model Install Jobs - * @description Return list of model install jobs. + * Heuristic Import + * @description Install a model using a string identifier. + * + * `source` can be any of the following. + * + * 1. A path on the local filesystem ('C:\users\fred\model.safetensors') + * 2. A Url pointing to a single downloadable model file + * 3. A HuggingFace repo_id with any of the following formats: + * - model/name + * - model/name:fp16:vae + * - model/name::vae -- use default precision + * - model/name:fp16:path/to/model.safetensors + * - model/name::path/to/model.safetensors + * + * `config` is an optional dict containing model configuration values that will override + * the ones that are probed automatically. + * + * `access_token` is an optional access token for use with Urls that require + * authentication. + * + * Models will be downloaded, probed, configured and installed in a + * series of background threads. The return object has `status` attribute + * that can be used to monitor progress. + * + * See the documentation for `import_model_record` for more information on + * interpreting the job information returned by this route. 
*/ - list_model_install_jobs: { + heuristic_import_model: { + parameters: { + query: { + source: string; + access_token?: string | null; + }; + }; + requestBody?: { + content: { + /** + * @example { + * "name": "modelT", + * "description": "antique cars" + * } + */ + "application/json": Record | null; + }; + }; responses: { - /** @description Successful Response */ - 200: { + /** @description The model imported successfully */ + 201: { content: { - "application/json": components["schemas"]["ModelInstallJob"][]; + "application/json": components["schemas"]["ModelInstallJob"]; + }; + }; + /** @description There is already a model corresponding to this path or repo_id */ + 409: { + content: never; + }; + /** @description Unrecognized file/folder format */ + 415: { + content: never; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; }; }; + /** @description The model appeared to import successfully, but could not be found in the model manager */ + 424: { + content: never; + }; }; }; /** * Import Model - * @description Add a model using its local path, repo_id, or remote URL. + * @description Install a model using its local path, repo_id, or remote URL. * * Models will be downloaded, probed, configured and installed in a * series of background threads. The return object has `status` attribute @@ -13840,32 +11628,38 @@ export type operations = { * appropriate value: * * * To install a local path using LocalModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "local", * "path": "/path/to/model", * "inplace": false - * }` - * The "inplace" flag, if true, will register the model in place in its - * current filesystem location. Otherwise, the model will be copied - * into the InvokeAI models directory. + * } + * ``` + * The "inplace" flag, if true, will register the model in place in its + * current filesystem location. Otherwise, the model will be copied + * into the InvokeAI models directory. * * * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "hf", * "repo_id": "stabilityai/stable-diffusion-2.0", * "variant": "fp16", * "subfolder": "vae", * "access_token": "f5820a918aaf01" - * }` - * The `variant`, `subfolder` and `access_token` fields are optional. + * } + * ``` + * The `variant`, `subfolder` and `access_token` fields are optional. * * * To install a remote model using an arbitrary URL, pass: - * `{ + * ``` + * { * "type": "url", * "url": "http://www.civitai.com/models/123456", * "access_token": "f5820a918aaf01" - * }` - * The `access_token` field is optonal + * } + * ``` + * The `access_token` field is optonal * * The model's configuration record will be probed and filled in * automatically. To override the default guesses, pass "metadata" @@ -13874,19 +11668,19 @@ export type operations = { * Installation occurs in the background. Either use list_model_install_jobs() * to poll for completion, or listen on the event bus for the following events: * - * "model_install_running" - * "model_install_completed" - * "model_install_error" + * * "model_install_running" + * * "model_install_completed" + * * "model_install_error" * * On successful completion, the event's payload will contain the field "key" * containing the installed ID of the model. On an error, the event's payload * will contain the fields "error_type" and "error" describing the nature of the * error and its traceback, respectively. 
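For quick reference, the three source payloads spelled out in the import-model docstring above look like this when written as plain TypeScript objects. The values are the docstring's own illustrations rather than real paths, models or tokens; the typed request-body schema is `Body_import_model` in this file.

```ts
// Source payloads as described in the docstring above; values are illustrative only.
const localSource = {
  type: "local",
  path: "/path/to/model",
  inplace: false, // false: copy into the InvokeAI models directory; true: register in place
};

const hfSource = {
  type: "hf",
  repo_id: "stabilityai/stable-diffusion-2.0",
  variant: "fp16",                // optional
  subfolder: "vae",               // optional
  access_token: "f5820a918aaf01", // optional
};

const urlSource = {
  type: "url",
  url: "http://www.civitai.com/models/123456",
  access_token: "f5820a918aaf01", // optional
};
```

Whichever shape is used, installation runs in the background; progress is reported through the install-job list and the `model_install_*` events noted in the docstring.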
*/ - import_model_record: { + import_model: { requestBody: { content: { - "application/json": components["schemas"]["Body_import_model_record"]; + "application/json": components["schemas"]["Body_import_model"]; }; }; responses: { @@ -13916,6 +11710,37 @@ export type operations = { }; }; }; + /** + * List Model Install Jobs + * @description Return the list of model install jobs. + * + * Install jobs have a numeric `id`, a `status`, and other fields that provide information on + * the nature of the job and its progress. The `status` is one of: + * + * * "waiting" -- Job is waiting in the queue to run + * * "downloading" -- Model file(s) are downloading + * * "running" -- Model has downloaded and the model probing and registration process is running + * * "completed" -- Installation completed successfully + * * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * * "cancelled" -- Job was cancelled before completion. + * + * Once completed, information about the model such as its size, base + * model, type, and metadata can be retrieved from the `config_out` + * field. For multi-file models such as diffusers, information on individual files + * can be retrieved from `download_parts`. + * + * See the example and schema below for more information. + */ + list_model_install_jobs: { + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["ModelInstallJob"][]; + }; + }; + }; + }; /** * Prune Model Install Jobs * @description Prune all completed and errored jobs from the install job list. @@ -13940,7 +11765,8 @@ export type operations = { }; /** * Get Model Install Job - * @description Return model install job corresponding to the given source. + * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' + * for information on the format of the return value. */ get_model_install_job: { parameters: { @@ -14023,16 +11849,59 @@ export type operations = { }; }; }; + /** + * Convert Model + * @description Permanently convert a model into diffusers format, replacing the safetensors version. + * Note that during the conversion process the key and model hash will change. + * The return value is the model configuration for the converted model. + */ + convert_model: { + parameters: { + path: { + /** @description Unique key of the safetensors main model to convert to diffusers format. 
   */
+        key: string;
+      };
+    };
+    responses: {
+      /** @description Model converted successfully */
+      200: {
+        content: {
+          "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
+        };
+      };
+      /** @description Bad request */
+      400: {
+        content: never;
+      };
+      /** @description Model not found */
+      404: {
+        content: never;
+      };
+      /** @description There is already a model registered at this location */
+      409: {
+        content: never;
+      };
+      /** @description Validation Error */
+      422: {
+        content: {
+          "application/json": components["schemas"]["HTTPValidationError"];
+        };
+      };
+    };
+  };
   /**
    * Merge
-   * @description Merge diffusers models.
-   *
-   *     keys: List of 2-3 model keys to merge together. All models must use the same base type.
-   *     merged_model_name: Name for the merged model [Concat model names]
-   *     alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
-   *     force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
-   *     interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
-   *     merge_dest_directory: Specify a directory to store the merged model in [models directory]
+   * @description Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
+   * ```
+   *        Argument                Description [default]
+   *        --------                ----------------------
+   *        keys                    List of 2-3 model keys to merge together. All models must use the same base type.
+   *        merged_model_name       Name for the merged model [Concat model names]
+   *        alpha                   Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
+   *        force                   If true, force the merge even if the models were generated by different versions of the diffusers library [False]
+   *        interp                  Interpolation method.
One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + * merge_dest_directory Specify a directory to store the merged model in [models directory] + * ``` */ merge: { requestBody: { @@ -14041,12 +11910,24 @@ export type operations = { }; }; responses: { - /** @description Successful Response */ + /** @description Model converted successfully */ 200: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description Model not found */ + 404: { + content: never; + }; + /** @description There is already a model registered at this location */ + 409: { + content: never; + }; /** @description Validation Error */ 422: { content: { diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 1382fbe275a..f9a1decf655 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -3,7 +3,7 @@ import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; -type s = components['schemas']; +export type S = components['schemas']; export type ImageCache = EntityState; @@ -23,60 +23,60 @@ export type BatchConfig = export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult']; -export type InputFieldJSONSchemaExtra = s['InputFieldJSONSchemaExtra']; -export type OutputFieldJSONSchemaExtra = s['OutputFieldJSONSchemaExtra']; -export type InvocationJSONSchemaExtra = s['UIConfigBase']; +export type InputFieldJSONSchemaExtra = S['InputFieldJSONSchemaExtra']; +export type OutputFieldJSONSchemaExtra = S['OutputFieldJSONSchemaExtra']; +export type InvocationJSONSchemaExtra = S['UIConfigBase']; // App Info -export type AppVersion = s['AppVersion']; -export type AppConfig = s['AppConfig']; -export type AppDependencyVersions = s['AppDependencyVersions']; +export type AppVersion = S['AppVersion']; +export type AppConfig = S['AppConfig']; +export type AppDependencyVersions = S['AppDependencyVersions']; // Images -export type ImageDTO = s['ImageDTO']; -export type BoardDTO = s['BoardDTO']; -export type BoardChanges = s['BoardChanges']; -export type ImageChanges = 
s['ImageRecordChanges']; -export type ImageCategory = s['ImageCategory']; -export type ResourceOrigin = s['ResourceOrigin']; -export type ImageField = s['ImageField']; -export type OffsetPaginatedResults_BoardDTO_ = s['OffsetPaginatedResults_BoardDTO_']; -export type OffsetPaginatedResults_ImageDTO_ = s['OffsetPaginatedResults_ImageDTO_']; +export type ImageDTO = S['ImageDTO']; +export type BoardDTO = S['BoardDTO']; +export type BoardChanges = S['BoardChanges']; +export type ImageChanges = S['ImageRecordChanges']; +export type ImageCategory = S['ImageCategory']; +export type ResourceOrigin = S['ResourceOrigin']; +export type ImageField = S['ImageField']; +export type OffsetPaginatedResults_BoardDTO_ = S['OffsetPaginatedResults_BoardDTO_']; +export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = s['invokeai__backend__model_management__models__base__ModelType']; -export type SubModelType = s['SubModelType']; -export type BaseModelType = s['invokeai__backend__model_management__models__base__BaseModelType']; -export type MainModelField = s['MainModelField']; -export type VAEModelField = s['VAEModelField']; -export type LoRAModelField = s['LoRAModelField']; -export type LoRAModelFormat = s['LoRAModelFormat']; -export type ControlNetModelField = s['ControlNetModelField']; -export type IPAdapterModelField = s['IPAdapterModelField']; -export type T2IAdapterModelField = s['T2IAdapterModelField']; -export type ModelsList = s['invokeai__app__api__routers__models__ModelsList']; -export type ControlField = s['ControlField']; -export type IPAdapterField = s['IPAdapterField']; +export type ModelType = S['ModelType']; +export type SubModelType = S['SubModelType']; +export type BaseModelType = S['BaseModelType']; +export type MainModelField = S['MainModelField']; +export type VAEModelField = S['VAEModelField']; +export type LoRAModelField = S['LoRAModelField']; +export type LoRAModelFormat = S['LoRAModelFormat']; +export type ControlNetModelField = S['ControlNetModelField']; +export type IPAdapterModelField = S['IPAdapterModelField']; +export type T2IAdapterModelField = S['T2IAdapterModelField']; +export type ModelsList = S['invokeai__app__api__routers__models__ModelsList']; +export type ControlField = S['ControlField']; +export type IPAdapterField = S['IPAdapterField']; // Model Configs -export type LoRAModelConfig = s['LoRAModelConfig']; -export type VaeModelConfig = s['VaeModelConfig']; -export type ControlNetModelCheckpointConfig = s['ControlNetModelCheckpointConfig']; -export type ControlNetModelDiffusersConfig = s['ControlNetModelDiffusersConfig']; +export type LoRAModelConfig = S['LoRAModelConfig']; +export type VaeModelConfig = S['VaeModelConfig']; +export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig']; +export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig']; export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig; -export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig']; +export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig']; export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; -export type T2IAdapterModelDiffusersConfig = s['T2IAdapterModelDiffusersConfig']; +export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig']; export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig; -export type TextualInversionModelConfig = 
s['TextualInversionModelConfig']; +export type TextualInversionModelConfig = S['TextualInversionModelConfig']; export type DiffusersModelConfig = - | s['StableDiffusion1ModelDiffusersConfig'] - | s['StableDiffusion2ModelDiffusersConfig'] - | s['StableDiffusionXLModelDiffusersConfig']; + | S['StableDiffusion1ModelDiffusersConfig'] + | S['StableDiffusion2ModelDiffusersConfig'] + | S['StableDiffusionXLModelDiffusersConfig']; export type CheckpointModelConfig = - | s['StableDiffusion1ModelCheckpointConfig'] - | s['StableDiffusion2ModelCheckpointConfig'] - | s['StableDiffusionXLModelCheckpointConfig']; + | S['StableDiffusion1ModelCheckpointConfig'] + | S['StableDiffusion2ModelCheckpointConfig'] + | S['StableDiffusionXLModelCheckpointConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = | LoRAModelConfig @@ -87,88 +87,87 @@ export type AnyModelConfig = | TextualInversionModelConfig | MainModelConfig; -export type MergeModelConfig = s['Body_merge_models']; -export type ImportModelConfig = s['Body_import_model']; +export type MergeModelConfig = S['Body_merge_models']; +export type ImportModelConfig = S['Body_import_model']; // Graphs -export type Graph = s['Graph']; +export type Graph = S['Graph']; export type NonNullableGraph = O.Required; -export type Edge = s['Edge']; -export type GraphExecutionState = s['GraphExecutionState']; -export type Batch = s['Batch']; -export type SessionQueueItemDTO = s['SessionQueueItemDTO']; -export type SessionQueueItem = s['SessionQueueItem']; -export type WorkflowRecordOrderBy = s['WorkflowRecordOrderBy']; -export type SQLiteDirection = s['SQLiteDirection']; -export type WorkflowDTO = s['WorkflowRecordDTO']; -export type WorkflowRecordListItemDTO = s['WorkflowRecordListItemDTO']; +export type Edge = S['Edge']; +export type GraphExecutionState = S['GraphExecutionState']; +export type Batch = S['Batch']; +export type SessionQueueItemDTO = S['SessionQueueItemDTO']; +export type SessionQueueItem = S['SessionQueueItem']; +export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy']; +export type SQLiteDirection = S['SQLiteDirection']; +export type WorkflowDTO = S['WorkflowRecordDTO']; +export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO']; // General nodes -export type CollectInvocation = s['CollectInvocation']; -export type IterateInvocation = s['IterateInvocation']; -export type RangeInvocation = s['RangeInvocation']; -export type RandomRangeInvocation = s['RandomRangeInvocation']; -export type RangeOfSizeInvocation = s['RangeOfSizeInvocation']; -export type ImageResizeInvocation = s['ImageResizeInvocation']; -export type ImageBlurInvocation = s['ImageBlurInvocation']; -export type ImageScaleInvocation = s['ImageScaleInvocation']; -export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation']; -export type InfillTileInvocation = s['InfillTileInvocation']; -export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation']; -export type MaskEdgeInvocation = s['MaskEdgeInvocation']; -export type RandomIntInvocation = s['RandomIntInvocation']; -export type CompelInvocation = s['CompelInvocation']; -export type DynamicPromptInvocation = s['DynamicPromptInvocation']; -export type NoiseInvocation = s['NoiseInvocation']; -export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation']; -export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation']; -export type ImageToLatentsInvocation = s['ImageToLatentsInvocation']; -export type LatentsToImageInvocation = 
s['LatentsToImageInvocation']; -export type ImageCollectionInvocation = s['ImageCollectionInvocation']; -export type MainModelLoaderInvocation = s['MainModelLoaderInvocation']; -export type LoraLoaderInvocation = s['LoraLoaderInvocation']; -export type ESRGANInvocation = s['ESRGANInvocation']; -export type DivideInvocation = s['DivideInvocation']; -export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; -export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; -export type SeamlessModeInvocation = s['SeamlessModeInvocation']; -export type LinearUIOutputInvocation = s['LinearUIOutputInvocation']; -export type MetadataInvocation = s['MetadataInvocation']; -export type CoreMetadataInvocation = s['CoreMetadataInvocation']; -export type MetadataItemInvocation = s['MetadataItemInvocation']; -export type MergeMetadataInvocation = s['MergeMetadataInvocation']; -export type IPAdapterMetadataField = s['IPAdapterMetadataField']; -export type T2IAdapterField = s['T2IAdapterField']; -export type LoRAMetadataField = s['LoRAMetadataField']; +export type CollectInvocation = S['CollectInvocation']; +export type IterateInvocation = S['IterateInvocation']; +export type RangeInvocation = S['RangeInvocation']; +export type RandomRangeInvocation = S['RandomRangeInvocation']; +export type RangeOfSizeInvocation = S['RangeOfSizeInvocation']; +export type ImageResizeInvocation = S['ImageResizeInvocation']; +export type ImageBlurInvocation = S['ImageBlurInvocation']; +export type ImageScaleInvocation = S['ImageScaleInvocation']; +export type InfillPatchMatchInvocation = S['InfillPatchMatchInvocation']; +export type InfillTileInvocation = S['InfillTileInvocation']; +export type CreateDenoiseMaskInvocation = S['CreateDenoiseMaskInvocation']; +export type MaskEdgeInvocation = S['MaskEdgeInvocation']; +export type RandomIntInvocation = S['RandomIntInvocation']; +export type CompelInvocation = S['CompelInvocation']; +export type DynamicPromptInvocation = S['DynamicPromptInvocation']; +export type NoiseInvocation = S['NoiseInvocation']; +export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation']; +export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation']; +export type ImageToLatentsInvocation = S['ImageToLatentsInvocation']; +export type LatentsToImageInvocation = S['LatentsToImageInvocation']; +export type ImageCollectionInvocation = S['ImageCollectionInvocation']; +export type MainModelLoaderInvocation = S['MainModelLoaderInvocation']; +export type LoraLoaderInvocation = S['LoraLoaderInvocation']; +export type ESRGANInvocation = S['ESRGANInvocation']; +export type DivideInvocation = S['DivideInvocation']; +export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation']; +export type ImageWatermarkInvocation = S['ImageWatermarkInvocation']; +export type SeamlessModeInvocation = S['SeamlessModeInvocation']; +export type MetadataInvocation = S['MetadataInvocation']; +export type CoreMetadataInvocation = S['CoreMetadataInvocation']; +export type MetadataItemInvocation = S['MetadataItemInvocation']; +export type MergeMetadataInvocation = S['MergeMetadataInvocation']; +export type IPAdapterMetadataField = S['IPAdapterMetadataField']; +export type T2IAdapterField = S['T2IAdapterField']; +export type LoRAMetadataField = S['LoRAMetadataField']; // ControlNet Nodes -export type ControlNetInvocation = s['ControlNetInvocation']; -export type T2IAdapterInvocation = s['T2IAdapterInvocation']; -export type IPAdapterInvocation = s['IPAdapterInvocation']; -export type 
CannyImageProcessorInvocation = s['CannyImageProcessorInvocation']; -export type ColorMapImageProcessorInvocation = s['ColorMapImageProcessorInvocation']; -export type ContentShuffleImageProcessorInvocation = s['ContentShuffleImageProcessorInvocation']; -export type DepthAnythingImageProcessorInvocation = s['DepthAnythingImageProcessorInvocation']; -export type HedImageProcessorInvocation = s['HedImageProcessorInvocation']; -export type LineartAnimeImageProcessorInvocation = s['LineartAnimeImageProcessorInvocation']; -export type LineartImageProcessorInvocation = s['LineartImageProcessorInvocation']; -export type MediapipeFaceProcessorInvocation = s['MediapipeFaceProcessorInvocation']; -export type MidasDepthImageProcessorInvocation = s['MidasDepthImageProcessorInvocation']; -export type MlsdImageProcessorInvocation = s['MlsdImageProcessorInvocation']; -export type NormalbaeImageProcessorInvocation = s['NormalbaeImageProcessorInvocation']; -export type DWOpenposeImageProcessorInvocation = s['DWOpenposeImageProcessorInvocation']; -export type PidiImageProcessorInvocation = s['PidiImageProcessorInvocation']; -export type ZoeDepthImageProcessorInvocation = s['ZoeDepthImageProcessorInvocation']; +export type ControlNetInvocation = S['ControlNetInvocation']; +export type T2IAdapterInvocation = S['T2IAdapterInvocation']; +export type IPAdapterInvocation = S['IPAdapterInvocation']; +export type CannyImageProcessorInvocation = S['CannyImageProcessorInvocation']; +export type ColorMapImageProcessorInvocation = S['ColorMapImageProcessorInvocation']; +export type ContentShuffleImageProcessorInvocation = S['ContentShuffleImageProcessorInvocation']; +export type DepthAnythingImageProcessorInvocation = S['DepthAnythingImageProcessorInvocation']; +export type HedImageProcessorInvocation = S['HedImageProcessorInvocation']; +export type LineartAnimeImageProcessorInvocation = S['LineartAnimeImageProcessorInvocation']; +export type LineartImageProcessorInvocation = S['LineartImageProcessorInvocation']; +export type MediapipeFaceProcessorInvocation = S['MediapipeFaceProcessorInvocation']; +export type MidasDepthImageProcessorInvocation = S['MidasDepthImageProcessorInvocation']; +export type MlsdImageProcessorInvocation = S['MlsdImageProcessorInvocation']; +export type NormalbaeImageProcessorInvocation = S['NormalbaeImageProcessorInvocation']; +export type DWOpenposeImageProcessorInvocation = S['DWOpenposeImageProcessorInvocation']; +export type PidiImageProcessorInvocation = S['PidiImageProcessorInvocation']; +export type ZoeDepthImageProcessorInvocation = S['ZoeDepthImageProcessorInvocation']; // Node Outputs -export type ImageOutput = s['ImageOutput']; -export type StringOutput = s['StringOutput']; -export type FloatOutput = s['FloatOutput']; -export type IntegerOutput = s['IntegerOutput']; -export type IterateInvocationOutput = s['IterateInvocationOutput']; -export type CollectInvocationOutput = s['CollectInvocationOutput']; -export type LatentsOutput = s['LatentsOutput']; -export type GraphInvocationOutput = s['GraphInvocationOutput']; +export type ImageOutput = S['ImageOutput']; +export type StringOutput = S['StringOutput']; +export type FloatOutput = S['FloatOutput']; +export type IntegerOutput = S['IntegerOutput']; +export type IterateInvocationOutput = S['IterateInvocationOutput']; +export type CollectInvocationOutput = S['CollectInvocationOutput']; +export type LatentsOutput = S['LatentsOutput']; +export type GraphInvocationOutput = S['GraphInvocationOutput']; // Post-image upload actions, controls 
workflows when images are uploaded diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index b76dd24b628..f4dbae71232 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -1,12 +1,93 @@ +/// +import react from '@vitejs/plugin-react-swc'; +import path from 'path'; +import { visualizer } from 'rollup-plugin-visualizer'; +import type { PluginOption } from 'vite'; import { defineConfig } from 'vite'; - -import { appConfig } from './config/vite.app.config.mjs'; -import { packageConfig } from './config/vite.package.config.mjs'; +import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; +import dts from 'vite-plugin-dts'; +import eslint from 'vite-plugin-eslint'; +import tsconfigPaths from 'vite-tsconfig-paths'; export default defineConfig(({ mode }) => { if (mode === 'package') { - return packageConfig; + return { + base: './', + plugins: [ + react(), + eslint(), + tsconfigPaths(), + visualizer() as unknown as PluginOption, + dts({ + insertTypesEntry: true, + }), + cssInjectedByJsPlugin(), + ], + build: { + cssCodeSplit: true, + lib: { + entry: path.resolve(__dirname, '../src/index.ts'), + name: 'InvokeAIUI', + fileName: (format) => `invoke-ai-ui.${format}.js`, + }, + rollupOptions: { + external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'], + output: { + globals: { + react: 'React', + 'react-dom': 'ReactDOM', + '@emotion/react': 'EmotionReact', + '@invoke-ai/ui-library': 'UiLibrary', + }, + }, + }, + }, + resolve: { + alias: { + app: path.resolve(__dirname, '../src/app'), + assets: path.resolve(__dirname, '../src/assets'), + common: path.resolve(__dirname, '../src/common'), + features: path.resolve(__dirname, '../src/features'), + services: path.resolve(__dirname, '../src/services'), + theme: path.resolve(__dirname, '../src/theme'), + }, + }, + }; } - return appConfig; + return { + base: './', + plugins: [react(), mode !== 'test' && eslint(), tsconfigPaths(), visualizer() as unknown as PluginOption], + build: { + chunkSizeWarningLimit: 1500, + }, + server: { + // Proxy HTTP requests to the flask server + proxy: { + // Proxy socket.io to the nodes socketio server + '/ws/socket.io': { + target: 'ws://127.0.0.1:9090', + ws: true, + }, + // Proxy openapi schema definiton + '/openapi.json': { + target: 'http://127.0.0.1:9090/openapi.json', + rewrite: (path) => path.replace(/^\/openapi.json/, ''), + changeOrigin: true, + }, + // proxy nodes api + '/api/v1': { + target: 'http://127.0.0.1:9090/api/v1', + rewrite: (path) => path.replace(/^\/api\/v1/, ''), + changeOrigin: true, + }, + }, + }, + test: { + typecheck: { + enabled: true, + ignoreSourceErrors: true, + }, + }, + }; }); diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py new file mode 100644 index 00000000000..e110b5a2db3 --- /dev/null +++ b/invokeai/invocation_api/__init__.py @@ -0,0 +1,187 @@ +""" +This file re-exports all the public API for invocations. This is the only file that should be imported by custom nodes. + +TODO(psyche): Do we want to dogfood this? 
+""" + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + invocation, + invocation_output, +) +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + FieldKind, + ImageField, + Input, + InputField, + LatentsField, + MetadataField, + OutputField, + UIComponent, + UIType, + WithMetadata, + WithWorkflow, +) +from invokeai.app.invocations.latent import SchedulerOutput +from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput +from invokeai.app.invocations.model import ( + ClipField, + CLIPOutput, + LoraInfo, + LoraLoaderOutput, + LoRAModelField, + MainModelField, + ModelInfo, + ModelLoaderOutput, + SDXLLoraLoaderOutput, + UNetField, + UNetOutput, + VaeField, + VAEModelField, + VAEOutput, +) +from invokeai.app.invocations.primitives import ( + BooleanCollectionOutput, + BooleanOutput, + ColorCollectionOutput, + ColorOutput, + ConditioningCollectionOutput, + ConditioningOutput, + DenoiseMaskOutput, + FloatCollectionOutput, + FloatOutput, + ImageCollectionOutput, + ImageOutput, + IntegerCollectionOutput, + IntegerOutput, + LatentsCollectionOutput, + LatentsOutput, + StringCollectionOutput, + StringOutput, +) +from invokeai.app.services.boards.boards_common import BoardDTO +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.backend.model_management.model_manager import LoadedModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, + ConditioningFieldData, + ExtraConditioningInfo, + SDXLConditioningInfo, +) +from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device +from invokeai.version import __version__ + +__all__ = [ + # invokeai.app.invocations.baseinvocation + "BaseInvocation", + "BaseInvocationOutput", + "Classification", + "invocation", + "invocation_output", + # invokeai.app.services.shared.invocation_context + "InvocationContext", + # invokeai.app.invocations.fields + "BoardField", + "ColorField", + "ConditioningField", + "DenoiseMaskField", + "FieldDescriptions", + "FieldKind", + "ImageField", + "Input", + "InputField", + "LatentsField", + "MetadataField", + "OutputField", + "UIComponent", + "UIType", + "WithMetadata", + "WithWorkflow", + # invokeai.app.invocations.latent + "SchedulerOutput", + # invokeai.app.invocations.metadata + "MetadataItemField", + "MetadataItemOutput", + "MetadataOutput", + # invokeai.app.invocations.model + "ModelInfo", + "LoraInfo", + "UNetField", + "ClipField", + "VaeField", + "MainModelField", + "LoRAModelField", + "VAEModelField", + "UNetOutput", + "VAEOutput", + "CLIPOutput", + "ModelLoaderOutput", + "LoraLoaderOutput", + "SDXLLoraLoaderOutput", + # invokeai.app.invocations.primitives + "BooleanCollectionOutput", + "BooleanOutput", + 
"ColorCollectionOutput", + "ColorOutput", + "ConditioningCollectionOutput", + "ConditioningOutput", + "DenoiseMaskOutput", + "FloatCollectionOutput", + "FloatOutput", + "ImageCollectionOutput", + "ImageOutput", + "IntegerCollectionOutput", + "IntegerOutput", + "LatentsCollectionOutput", + "LatentsOutput", + "StringCollectionOutput", + "StringOutput", + # invokeai.app.services.image_records.image_records_common + "ImageCategory", + # invokeai.app.services.boards.boards_common + "BoardDTO", + # invokeai.backend.stable_diffusion.diffusion.conditioning_data + "BasicConditioningInfo", + "ConditioningFieldData", + "ExtraConditioningInfo", + "SDXLConditioningInfo", + # invokeai.backend.stable_diffusion.diffusers_pipeline + "PipelineIntermediateState", + # invokeai.app.services.workflow_records.workflow_records_common + "WorkflowWithoutID", + # invokeai.app.services.config.config_default + "InvokeAIAppConfig", + # invokeai.backend.model_management.model_manager + "LoadedModelInfo", + # invokeai.backend.model_management.models.base + "BaseModelType", + "ModelType", + "SubModelType", + # invokeai.app.invocations.constants + "SCHEDULER_NAME_VALUES", + # invokeai.version + "__version__", + # invokeai.backend.util.devices + "choose_precision", + "choose_torch_device", + "CPU_DEVICE", + "CUDA_DEVICE", + "MPS_DEVICE", + # invokeai.app.util.misc + "SEED_MAX", + "get_random_seed", +] diff --git a/pyproject.toml b/pyproject.toml index 7f4b0d77f25..f57607bc0af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "albumentations", "click", "datasets", + "Deprecated", "dnspython~=2.4.0", "dynamicprompts", "easing-functions", @@ -135,8 +136,7 @@ dependencies = [ # full commands "invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure" -"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers" -"invokeai-merge2" = "invokeai.frontend.merge.merge_diffusers2:main" +"invokeai-merge" = "invokeai.frontend.merge.merge_diffusers:main" "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion" "invokeai-model-install" = "invokeai.frontend.install.model_install:main" "invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install @@ -169,6 +169,7 @@ version = { attr = "invokeai.version.__version__" } "invokeai.frontend.web.static*", "invokeai.configs*", "invokeai.app*", + "invokeai.invocation_api*", ] [tool.setuptools.package-data] @@ -244,7 +245,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker", @@ -280,3 +281,38 @@ module = [ "invokeai.frontend.install.model_install", ] #=== End: MyPy + +[tool.pyright] +# Start from strict mode +typeCheckingMode = "strict" +# This errors whenever an import is missing a type stub file - way too noisy +reportMissingTypeStubs = "none" +# These are the rest of the rules enabled by strict mode - enable them @ warning +reportConstantRedefinition = "warning" +reportDeprecated = "warning" +reportDuplicateImport = "warning" +reportIncompleteStub = "warning" +reportInconsistentConstructor = "warning" +reportInvalidStubStatement = "warning" 
+reportMatchNotExhaustive = "warning" +reportMissingParameterType = "warning" +reportMissingTypeArgument = "warning" +reportPrivateUsage = "warning" +reportTypeCommentUsage = "warning" +reportUnknownArgumentType = "warning" +reportUnknownLambdaType = "warning" +reportUnknownMemberType = "warning" +reportUnknownParameterType = "warning" +reportUnknownVariableType = "warning" +reportUnnecessaryCast = "warning" +reportUnnecessaryComparison = "warning" +reportUnnecessaryContains = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnusedClass = "warning" +reportUnusedImport = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "warning" +reportUntypedBaseClass = "warning" +reportUntypedClassDecorator = "warning" +reportUntypedFunctionDecorator = "warning" +reportUntypedNamedTuple = "warning" diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598f..f839a4a8785 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -21,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -61,12 +60,9 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), @@ -75,6 +71,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) @@ -86,12 +84,16 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - queue_batch_id="1", - queue_item_id=1, - queue_id=DEFAULT_QUEUE_ID, - services=services, - graph_execution_state_id="1", - workflow=None, + conditioning=None, + config=None, + context_data=None, + images=None, + tensors=None, + logger=None, + models=None, + util=None, + boards=None, + services=None, ) ) g.complete(n.id, o) diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a0..774f7501dc2 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -63,12 +63,9 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), @@ -77,6 +74,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore 
workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index bca4e1011f8..aab3d9c7b4b 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,13 +3,12 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, - InvocationContext, - OutputField, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField +from invokeai.app.services.shared.invocation_context import InvocationContext # Define test invocations before importing anything that uses invocations diff --git a/tests/backend/model_manager_2/model_loading/test_model_load.py b/tests/backend/model_manager_2/model_loading/test_model_load.py new file mode 100644 index 00000000000..a7a64e91ac0 --- /dev/null +++ b/tests/backend/model_manager_2/model_loading/test_model_load.py @@ -0,0 +1,22 @@ +""" +Test model loading +""" + +from pathlib import Path + +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager.load import AnyModelLoader +from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 + + +def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path): + store = mm2_installer.record_store + matches = store.search_by_attr(model_name="test_embedding") + assert len(matches) == 0 + key = mm2_installer.register_path(embedding_file) + loaded_model = mm2_loader.load_model(store.get_model(key)) + assert loaded_model is not None + assert loaded_model.config.key == key + with loaded_model as model: + assert isinstance(model, TextualInversionModelRaw) diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d6d091befea..d85eab67dd3 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -20,6 +20,7 @@ ModelFormat, ModelType, ) +from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( @@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: return app_config +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader: + logger = InvokeAILogger.get_logger(config=mm2_app_config) + ram_cache = ModelCache( + logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size + ) + convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) + return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache) + + @pytest.fixture def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config) diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager_2/util/test_hf_model_select.py index f14d9a6823a..5bef9cb2e19 100644 --- a/tests/backend/model_manager_2/util/test_hf_model_select.py 
+++ b/tests/backend/model_manager_2/util/test_hf_model_select.py @@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]: "text_encoder/model.onnx", "text_encoder_2/config.json", "text_encoder_2/model.onnx", + "text_encoder_2/model.onnx_data", "tokenizer/merges.txt", "tokenizer/special_tokens_map.json", "tokenizer/tokenizer_config.json", @@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]: "tokenizer_2/vocab.json", "unet/config.json", "unet/model.onnx", + "unet/model.onnx_data", "vae_decoder/config.json", "vae_decoder/model.onnx", "vae_encoder/config.json", diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py deleted file mode 100644 index 3e48c7ed6fc..00000000000 --- a/tests/test_model_manager.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path - -import pytest - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType - -BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main) - - -@pytest.fixture -def model_manager(datadir) -> ModelManager: - InvokeAIAppConfig.get_config(root=datadir) - return ModelManager(datadir / "configs" / "relative_sub.models.yaml") - - -def test_get_model_names(model_manager: ModelManager): - names = model_manager.model_names() - assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] - - -def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) - top_model_path, is_override = model_manager._get_model_path(model_config) - expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" - assert top_model_path == expected_model_path - assert not is_override - - -def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" - assert vae_model_path == expected_vae_path - assert is_override - - -def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - assert not is_override diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 248b7d602fd..be823e2be9f 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -2,8 +2,8 @@ import pytest -from invokeai.backend import BaseModelType -from invokeai.backend.model_management.model_probe import VaeFolderProbe +from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant +from invokeai.backend.model_manager.probe import VaeFolderProbe @pytest.mark.parametrize( @@ -20,3 +20,11 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat probe = VaeFolderProbe(sd1_vae_path) base_type = probe.get_base_type() assert base_type == expected_type + 
repo_variant = probe.get_repo_variant() + assert repo_variant == ModelRepoVariant.DEFAULT + + +def test_repo_variant(datadir: Path): + probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") + repo_variant = probe.get_repo_variant() + assert repo_variant == ModelRepoVariant.FP16 diff --git a/tests/test_model_probe/vae/taesdxl-fp16/config.json b/tests/test_model_probe/vae/taesdxl-fp16/config.json new file mode 100644 index 00000000000..62f01c3eb44 --- /dev/null +++ b/tests/test_model_probe/vae/taesdxl-fp16/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +} diff --git a/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors b/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py new file mode 100644 index 00000000000..125534c5002 --- /dev/null +++ b/tests/test_object_serializer_disk.py @@ -0,0 +1,172 @@ +import tempfile +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch + +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache + + +@dataclass +class MockDataclass: + foo: str + + +def count_files(path: Path): + return len(list(path.iterdir())) + + +@pytest.fixture +def obj_serializer(tmp_path: Path): + return ObjectSerializerDisk[MockDataclass](tmp_path) + + +@pytest.fixture +def fwd_cache(tmp_path: Path): + return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) + + +def test_obj_serializer_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) + assert obj_serializer._output_dir == tmp_path + + +def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert Path(obj_serializer._output_dir, obj_1_name).exists() + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert obj_serializer.load(obj_1_name).foo == "bar" + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert obj_serializer.load(obj_2_name).foo == "baz" + + with pytest.raises(ObjectNotFoundError): + obj_serializer.load("nonexistent_object_name") + + +def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = 
obj_serializer.save(obj_2) + + obj_serializer.delete(obj_1_name) + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory) + assert obj_serializer._base_output_dir == tmp_path + assert obj_serializer._output_dir != tmp_path + assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name) + + +def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + del obj_serializer + assert not tempdir_path.exists() + + +def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + obj_serializer.stop(None) # pyright: ignore [reportArgumentType] + assert not tempdir_path.exists() + + +def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert Path(obj_serializer._output_dir, obj_1_name).exists() + assert not Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_different_types(tmp_path: Path): + obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path) + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer_1.save(obj_1) + obj_1_loaded = obj_serializer_1.load(obj_1_name) + assert obj_serializer_1._obj_class_name == "MockDataclass" + assert isinstance(obj_1_loaded, MockDataclass) + assert obj_1_loaded.foo == "bar" + assert obj_1_name.startswith("MockDataclass_") + + obj_serializer_2 = ObjectSerializerDisk[int](tmp_path) + obj_2_name = obj_serializer_2.save(9001) + assert obj_serializer_2._obj_class_name == "int" + assert obj_serializer_2.load(obj_2_name) == 9001 + assert obj_2_name.startswith("int_") + + obj_serializer_3 = ObjectSerializerDisk[str](tmp_path) + obj_3_name = obj_serializer_3.save("foo") + assert obj_serializer_3._obj_class_name == "str" + assert obj_serializer_3.load(obj_3_name) == "foo" + assert obj_3_name.startswith("str_") + + obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer_4.load(obj_4_name) + assert obj_serializer_4._obj_class_name == "Tensor" + assert isinstance(obj_4_loaded, torch.Tensor) + assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) + assert obj_4_name.startswith("Tensor_") + + +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]): + fwd_cache = ObjectSerializerForwardCache(obj_serializer) + assert fwd_cache._underlying_storage == obj_serializer + + +def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj = MockDataclass(foo="bar") + obj_name = fwd_cache.save(obj) + obj_loaded = fwd_cache.load(obj_name) + obj_underlying = fwd_cache._underlying_storage.load(obj_name) + assert obj_loaded == obj_underlying + assert obj_loaded.foo == "bar" + + +def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj_1 = 
MockDataclass(foo="bar") + obj_1_name = fwd_cache.save(obj_1) + obj_2 = MockDataclass(foo="baz") + obj_2_name = fwd_cache.save(obj_2) + obj_3 = MockDataclass(foo="qux") + obj_3_name = fwd_cache.save(obj_3) + assert obj_1_name not in fwd_cache._cache + assert obj_2_name in fwd_cache._cache + assert obj_3_name in fwd_cache._cache + # apparently qsize is "not reliable"? + assert fwd_cache._cache_ids.qsize() == 2 + + +def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + called_name = None + obj_1 = MockDataclass(foo="bar") + + def on_deleted(name: str): + nonlocal called_name + called_name = name + + fwd_cache.on_deleted(on_deleted) + obj_1_name = fwd_cache.save(obj_1) + fwd_cache.delete(obj_1_name) + assert called_name == obj_1_name
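
Stepping back from the tests: the new `invokeai/invocation_api/__init__.py` introduced earlier in this diff is intended as the single import surface for custom nodes, in place of reaching into `invokeai.app.*` modules directly. As a rough illustration of what that buys a node author, here is a minimal invocation written only against those re-exports. The decorator keyword arguments (`title`, `tags`, `version`) and the `value=` field on `StringOutput` reflect the conventional invocation API and should be read as assumptions rather than something this diff itself specifies.

```python
# Minimal sketch of a custom node that imports only from the new public
# re-export surface. Every name imported here appears in the __all__ list of
# invokeai/invocation_api/__init__.py above; the decorator keyword arguments and
# the StringOutput(value=...) constructor follow the usual invocation
# conventions and are assumptions, not something defined by this diff.
from invokeai.invocation_api import (
    BaseInvocation,
    InputField,
    InvocationContext,
    StringOutput,
    invocation,
)


@invocation("shout", title="Shout", tags=["string"], version="1.0.0")
class ShoutInvocation(BaseInvocation):
    """Upper-cases a string and appends an exclamation mark."""

    text: str = InputField(description="The string to shout")

    def invoke(self, context: InvocationContext) -> StringOutput:
        return StringOutput(value=self.text.upper() + "!")
```

Because the node only touches `invokeai.invocation_api`, it stays insulated from the internal moves visible elsewhere in this diff (for example, `InputField` and `InvocationContext` relocating out of `baseinvocation`).
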