From aa089e81082c05433687a98005665087124cba0a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:23:06 +1100 Subject: [PATCH 001/100] tidy(nodes): move all field things to fields.py Unfortunately, this is necessary to prevent circular imports at runtime. --- invokeai/app/api/routers/images.py | 2 +- invokeai/app/api_app.py | 3 +- invokeai/app/invocations/baseinvocation.py | 453 +--------------- invokeai/app/invocations/collections.py | 3 +- invokeai/app/invocations/compel.py | 6 +- .../controlnet_image_processors.py | 6 +- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/facetools.py | 4 +- invokeai/app/invocations/fields.py | 501 ++++++++++++++++++ invokeai/app/invocations/image.py | 5 +- invokeai/app/invocations/infill.py | 3 +- invokeai/app/invocations/ip_adapter.py | 5 +- invokeai/app/invocations/latent.py | 7 +- invokeai/app/invocations/math.py | 4 +- invokeai/app/invocations/metadata.py | 6 +- invokeai/app/invocations/model.py | 5 +- invokeai/app/invocations/noise.py | 4 +- invokeai/app/invocations/onnx.py | 16 +- invokeai/app/invocations/param_easing.py | 3 +- invokeai/app/invocations/primitives.py | 6 +- invokeai/app/invocations/prompt.py | 3 +- invokeai/app/invocations/sdxl.py | 6 +- invokeai/app/invocations/strings.py | 4 +- invokeai/app/invocations/t2i_adapter.py | 5 +- invokeai/app/invocations/tiles.py | 5 +- invokeai/app/invocations/upscale.py | 3 +- .../services/image_files/image_files_base.py | 2 +- .../services/image_files/image_files_disk.py | 2 +- .../image_records/image_records_base.py | 2 +- .../image_records/image_records_sqlite.py | 2 +- invokeai/app/services/images/images_base.py | 2 +- .../app/services/images/images_default.py | 2 +- invokeai/app/services/shared/graph.py | 5 +- invokeai/app/shared/fields.py | 67 --- invokeai/app/shared/models.py | 2 +- tests/aa_nodes/test_nodes.py | 3 +- 36 files changed, 552 insertions(+), 608 deletions(-) create mode 100644 invokeai/app/invocations/fields.py delete mode 100644 invokeai/app/shared/fields.py diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 125896b8d3a..cc60ad1be83 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -8,7 +8,7 @@ from PIL import Image from pydantic import BaseModel, Field, ValidationError -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6294083d0e1..f48074de7c7 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -6,6 +6,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.version.invokeai_version import __version__ +from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from .services.config import InvokeAIAppConfig app_config = InvokeAIAppConfig.get_config() @@ -57,8 +58,6 @@ from .api.sockets import SocketIO from .invocations.baseinvocation import ( BaseInvocation, - InputFieldJSONSchemaExtra, - OutputFieldJSONSchemaExtra, UIConfigBase, ) diff --git a/invokeai/app/invocations/baseinvocation.py 
b/invokeai/app/invocations/baseinvocation.py index d9e0c7ba0d2..395d5e98707 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -12,10 +12,11 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast import semver -from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model -from pydantic.fields import FieldInfo, _Unset +from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from invokeai.app.invocations.fields import FieldKind, Input from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.shared.fields import FieldDescriptions @@ -52,393 +53,6 @@ class Classification(str, Enum, metaclass=MetaEnum): Prototype = "prototype" -class Input(str, Enum, metaclass=MetaEnum): - """ - The type of input a field accepts. - - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ - are instantiated. - - `Input.Connection`: The field must have its value provided by a connection. - - `Input.Any`: The field may have its value provided either directly or by a connection. - """ - - Connection = "connection" - Direct = "direct" - Any = "any" - - -class FieldKind(str, Enum, metaclass=MetaEnum): - """ - The kind of field. - - `Input`: An input field on a node. - - `Output`: An output field on a node. - - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is - one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name - "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, - allowing "metadata" for that field. - - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, - but which are used to store information about the node. For example, the `id` and `type` fields are node - attributes. - - The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app - startup, and when generating the OpenAPI schema for the workflow editor. - """ - - Input = "input" - Output = "output" - Internal = "internal" - NodeAttribute = "node_attribute" - - -class UIType(str, Enum, metaclass=MetaEnum): - """ - Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. - - - Model Fields - The most common node-author-facing use will be for model fields. Internally, there is no difference - between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the - base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that - the field is an SDXL main model field. - - - Any Field - We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to - indicate that the field accepts any type. Use with caution. This cannot be used on outputs. - - - Scheduler Field - Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. - - - Internal Fields - Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. 
To facilitate - handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These - should not be used by node authors. - - - DEPRECATED Fields - These types are deprecated and should not be used by node authors. A warning will be logged if one is - used, and the type will be ignored. They are included here for backwards compatibility. - """ - - # region Model Field Types - SDXLMainModel = "SDXLMainModelField" - SDXLRefinerModel = "SDXLRefinerModelField" - ONNXModel = "ONNXModelField" - VaeModel = "VAEModelField" - LoRAModel = "LoRAModelField" - ControlNetModel = "ControlNetModelField" - IPAdapterModel = "IPAdapterModelField" - # endregion - - # region Misc Field Types - Scheduler = "SchedulerField" - Any = "AnyField" - # endregion - - # region Internal Field Types - _Collection = "CollectionField" - _CollectionItem = "CollectionItemField" - # endregion - - # region DEPRECATED - Boolean = "DEPRECATED_Boolean" - Color = "DEPRECATED_Color" - Conditioning = "DEPRECATED_Conditioning" - Control = "DEPRECATED_Control" - Float = "DEPRECATED_Float" - Image = "DEPRECATED_Image" - Integer = "DEPRECATED_Integer" - Latents = "DEPRECATED_Latents" - String = "DEPRECATED_String" - BooleanCollection = "DEPRECATED_BooleanCollection" - ColorCollection = "DEPRECATED_ColorCollection" - ConditioningCollection = "DEPRECATED_ConditioningCollection" - ControlCollection = "DEPRECATED_ControlCollection" - FloatCollection = "DEPRECATED_FloatCollection" - ImageCollection = "DEPRECATED_ImageCollection" - IntegerCollection = "DEPRECATED_IntegerCollection" - LatentsCollection = "DEPRECATED_LatentsCollection" - StringCollection = "DEPRECATED_StringCollection" - BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" - ColorPolymorphic = "DEPRECATED_ColorPolymorphic" - ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" - ControlPolymorphic = "DEPRECATED_ControlPolymorphic" - FloatPolymorphic = "DEPRECATED_FloatPolymorphic" - ImagePolymorphic = "DEPRECATED_ImagePolymorphic" - IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" - LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" - StringPolymorphic = "DEPRECATED_StringPolymorphic" - MainModel = "DEPRECATED_MainModel" - UNet = "DEPRECATED_UNet" - Vae = "DEPRECATED_Vae" - CLIP = "DEPRECATED_CLIP" - Collection = "DEPRECATED_Collection" - CollectionItem = "DEPRECATED_CollectionItem" - Enum = "DEPRECATED_Enum" - WorkflowField = "DEPRECATED_WorkflowField" - IsIntermediate = "DEPRECATED_IsIntermediate" - BoardField = "DEPRECATED_BoardField" - MetadataItem = "DEPRECATED_MetadataItem" - MetadataItemCollection = "DEPRECATED_MetadataItemCollection" - MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" - MetadataDict = "DEPRECATED_MetadataDict" - # endregion - - -class UIComponent(str, Enum, metaclass=MetaEnum): - """ - The type of UI component to use for a field, used to override the default components, which are - inferred from the field type. - """ - - None_ = "none" - Textarea = "textarea" - Slider = "slider" - - -class InputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, - and by the workflow editor during schema parsing and UI rendering. 
- """ - - input: Input - orig_required: bool - field_kind: FieldKind - default: Optional[Any] = None - orig_default: Optional[Any] = None - ui_hidden: bool = False - ui_type: Optional[UIType] = None - ui_component: Optional[UIComponent] = None - ui_order: Optional[int] = None - ui_choice_labels: Optional[dict[str, str]] = None - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -class OutputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor - during schema parsing and UI rendering. - """ - - field_kind: FieldKind - ui_hidden: bool - ui_type: Optional[UIType] - ui_order: Optional[int] - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -def InputField( - # copied from pydantic's Field - # TODO: Can we support default_factory? - default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - input: Input = Input.Any, - ui_type: Optional[UIType] = None, - ui_component: Optional[UIComponent] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, - ui_choice_labels: Optional[dict[str, str]] = None, -) -> Any: - """ - Creates an input field for an invocation. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param Input input: [Input.Any] The kind of input this field requires. \ - `Input.Direct` means a value must be provided on instantiation. \ - `Input.Connection` means the value must be provided by a connection. \ - `Input.Any` means either will do. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ - The UI will always render a suitable component, but sometimes you want something different than the default. \ - For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ - For this case, you could provide `UIComponent.Textarea`. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - - :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. 
- """ - - json_schema_extra_ = InputFieldJSONSchemaExtra( - input=input, - ui_type=ui_type, - ui_component=ui_component, - ui_hidden=ui_hidden, - ui_order=ui_order, - ui_choice_labels=ui_choice_labels, - field_kind=FieldKind.Input, - orig_required=True, - ) - - """ - There is a conflict between the typing of invocation definitions and the typing of an invocation's - `invoke()` function. - - On instantiation of a node, the invocation definition is used to create the python class. At this time, - any number of fields may be optional, because they may be provided by connections. - - On calling of `invoke()`, however, those fields may be required. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - connection from an ancestor node, which outputs an image. - - This means we want to type the `image` field as optional for the node class definition, but required - for the `invoke()` function. - - If we use `typing.Optional` in the node class definition, the field will be typed as optional in the - `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or - any static type analysis tools will complain. - - To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, - but secretly make them optional in `InputField()`. We also store the original required bool and/or default - value. When we call `invoke()`, we use this stored information to do an additional check on the class. - """ - - if default_factory is not _Unset and default_factory is not None: - default = default_factory() - logger.warn('"default_factory" is not supported, calling it now to set "default"') - - # These are the args we may wish pass to the pydantic `Field()` function - field_args = { - "default": default, - "title": title, - "description": description, - "pattern": pattern, - "strict": strict, - "gt": gt, - "ge": ge, - "lt": lt, - "le": le, - "multiple_of": multiple_of, - "allow_inf_nan": allow_inf_nan, - "max_digits": max_digits, - "decimal_places": decimal_places, - "min_length": min_length, - "max_length": max_length, - } - - # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected - provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - - # Because we are manually making fields optional, we need to store the original required bool for reference later - json_schema_extra_.orig_required = default is PydanticUndefined - - # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if input is Input.Any or input is Input.Connection: - default_ = None if default is PydanticUndefined else default - provided_args.update({"default": default_}) - if default is not PydanticUndefined: - # Before invoking, we'll check for the original default value and set it on the field if the field has no value - json_schema_extra_.default = default - json_schema_extra_.orig_default = default - elif default is not PydanticUndefined: - default_ = default - provided_args.update({"default": default_}) - json_schema_extra_.orig_default = default_ - - return Field( - **provided_args, - json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), - ) - - -def OutputField( - # copied from 
pydantic's Field - default: Any = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - ui_type: Optional[UIType] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, -) -> Any: - """ - Creates an output field for an invocation output. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ - """ - return Field( - default=default, - title=title, - description=description, - pattern=pattern, - strict=strict, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, - max_digits=max_digits, - decimal_places=decimal_places, - min_length=min_length, - max_length=max_length, - json_schema_extra=OutputFieldJSONSchemaExtra( - ui_type=ui_type, - ui_hidden=ui_hidden, - ui_order=ui_order, - field_kind=FieldKind.Output, - ).model_dump(exclude_none=True), - ) - - class UIConfigBase(BaseModel): """ Provides additional node configuration to the UI. @@ -460,33 +74,6 @@ class UIConfigBase(BaseModel): ) -class InvocationContext: - """Initialized and provided to on execution of invocations.""" - - services: InvocationServices - graph_execution_state_id: str - queue_id: str - queue_item_id: int - queue_batch_id: str - workflow: Optional[WorkflowWithoutID] - - def __init__( - self, - services: InvocationServices, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - workflow: Optional[WorkflowWithoutID], - ): - self.services = services - self.graph_execution_state_id = graph_execution_state_id - self.queue_id = queue_id - self.queue_item_id = queue_item_id - self.queue_batch_id = queue_batch_id - self.workflow = workflow - - class BaseInvocationOutput(BaseModel): """ Base class for all invocation outputs. @@ -926,37 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper - - -class MetadataField(RootModel): - """ - Pydantic model for metadata with custom root of type dict[str, Any]. - Metadata is stored without a strict schema. 
- """ - - root: dict[str, Any] = Field(description="The metadata") - - -MetadataFieldValidator = TypeAdapter(MetadataField) - - -class WithMetadata(BaseModel): - metadata: Optional[MetadataField] = Field( - default=None, - description=FieldDescriptions.metadata, - json_schema_extra=InputFieldJSONSchemaExtra( - field_kind=FieldKind.Internal, - input=Input.Connection, - orig_required=False, - ).model_dump(exclude_none=True), - ) - - -class WithWorkflow: - workflow = None - - def __init_subclass__(cls) -> None: - logger.warn( - f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." - ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 4c7b6f94cd4..d35a9d79c74 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -7,7 +7,8 @@ from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 49c62cff564..b386aef2cbe 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,8 +5,8 @@ from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, @@ -20,11 +20,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1f9342985a0..9b652b8eee9 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,10 +25,10 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector @@ -36,11 +36,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index cb6828d21ac..5865338e192 100644 --- a/invokeai/app/invocations/cv.py +++ 
b/invokeai/app/invocations/cv.py @@ -8,7 +8,8 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index e0c89b4de5a..13f1066ec3e 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,13 +13,11 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py new file mode 100644 index 00000000000..0cce8e3c6b5 --- /dev/null +++ b/invokeai/app/invocations/fields.py @@ -0,0 +1,501 @@ +from enum import Enum +from typing import Any, Callable, Optional + +from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter +from pydantic.fields import _Unset +from pydantic_core import PydanticUndefined + +from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger() + + +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + + - Model Fields + The most common node-author-facing use will be for model fields. Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. 
+ """ + + # region Model Field Types + SDXLMainModel = "SDXLMainModelField" + SDXLRefinerModel = "SDXLRefinerModelField" + ONNXModel = "ONNXModelField" + VaeModel = "VAEModelField" + LoRAModel = "LoRAModelField" + ControlNetModel = "ControlNetModelField" + IPAdapterModel = "IPAdapterModelField" + # endregion + + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" + # endregion + + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" + # endregion + + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" + + +class UIComponent(str, Enum, metaclass=MetaEnum): + """ + The type of UI component to use for a field, used to override the default components, which are + inferred from the field type. 
+ """ + + None_ = "none" + Textarea = "textarea" + Slider = "slider" + + +class FieldDescriptions: + denoising_start = "When to start denoising, expressed a percentage of total steps" + denoising_end = "When to stop denoising, expressed a percentage of total steps" + cfg_scale = "Classifier-Free Guidance scale" + cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" + scheduler = "Scheduler to use during inference" + positive_cond = "Positive conditioning tensor" + negative_cond = "Negative conditioning tensor" + noise = "Noise tensor" + clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" + unet = "UNet (scheduler, LoRAs)" + vae = "VAE" + cond = "Conditioning tensor" + controlnet_model = "ControlNet model to load" + vae_model = "VAE model to load" + lora_model = "LoRA model to load" + main_model = "Main model (UNet, VAE, CLIP) to load" + sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" + sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" + onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + lora_weight = "The weight at which the LoRA is applied to each model" + compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" + raw_prompt = "Raw prompt text (no parsing)" + sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" + skipped_layers = "Number of layers to skip in text encoder" + seed = "Seed for random number generation" + steps = "Number of steps to run" + width = "Width of output (px)" + height = "Height of output (px)" + control = "ControlNet(s) to apply" + ip_adapter = "IP-Adapter to apply" + t2i_adapter = "T2I-Adapter(s) to apply" + denoised_latents = "Denoised latents tensor" + latents = "Latents tensor" + strength = "Strength of denoising (proportional to steps)" + metadata = "Optional metadata to be saved with the image" + metadata_collection = "Collection of Metadata" + metadata_item_polymorphic = "A single metadata item or collection of metadata items" + metadata_item_label = "Label for this metadata item" + metadata_item_value = "The value for this metadata item (may be any type)" + workflow = "Optional workflow to be saved with the image" + interp_mode = "Interpolation mode" + torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" + fp32 = "Whether or not to use full float32 precision" + precision = "Precision to use" + tiled = "Processing using overlapping tiles (reduce memory consumption)" + detect_res = "Pixel resolution for detection" + image_res = "Pixel resolution for output image" + safe_mode = "Whether or not to use safe mode" + scribble_mode = "Whether or not to use scribble mode" + scale_factor = "The factor by which to scale" + blend_alpha = ( + "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." + ) + num_1 = "The first number" + num_2 = "The second number" + mask = "The mask to use for the operation" + board = "The board to save the image to" + image = "The image to process" + tile_size = "Tile size" + inclusive_low = "The inclusive low value" + exclusive_high = "The exclusive high value" + decimal_places = "The number of decimal places to round to" + freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' 
+ freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' + freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." + freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." + + +class MetadataField(RootModel): + """ + Pydantic model for metadata with custom root of type dict[str, Any]. + Metadata is stored without a strict schema. + """ + + root: dict[str, Any] = Field(description="The metadata") + + +MetadataFieldValidator = TypeAdapter(MetadataField) + + +class Input(str, Enum, metaclass=MetaEnum): + """ + The type of input a field accepts. + - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ + are instantiated. + - `Input.Connection`: The field must have its value provided by a connection. + - `Input.Any`: The field may have its value provided either directly or by a connection. + """ + + Connection = "connection" + Direct = "direct" + Any = "any" + + +class FieldKind(str, Enum, metaclass=MetaEnum): + """ + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. + """ + + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" + + +class InputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. + """ + + input: Input + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +class WithMetadata(BaseModel): + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." + ) + super().__init_subclass__() + + +class OutputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to output fields and their OpenAPI schema.
Used by the workflow editor + during schema parsing and UI rendering. + """ + + field_kind: FieldKind + ui_hidden: bool + ui_type: Optional[UIType] + ui_order: Optional[int] + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +def InputField( + # copied from pydantic's Field + # TODO: Can we support default_factory? + default: Any = _Unset, + default_factory: Callable[[], Any] | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + input: Input = Input.Any, + ui_type: Optional[UIType] = None, + ui_component: Optional[UIComponent] = None, + ui_hidden: bool = False, + ui_order: Optional[int] = None, + ui_choice_labels: Optional[dict[str, str]] = None, +) -> Any: + """ + Creates an input field for an invocation. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param Input input: [Input.Any] The kind of input this field requires. \ + `Input.Direct` means a value must be provided on instantiation. \ + `Input.Connection` means the value must be provided by a connection. \ + `Input.Any` means either will do. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ + The UI will always render a suitable component, but sometimes you want something different than the default. \ + For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ + For this case, you could provide `UIComponent.Textarea`. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. + + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. + """ + + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. 
+ + On calling of `invoke()`, however, those fields may be required. + + For example, consider a ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. + """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + logger.warn('"default_factory" is not supported, calling it now to set "default"') + + # These are the args we may wish to pass to the pydantic `Field()` function + field_args = { + "default": default, + "title": title, + "description": description, + "pattern": pattern, + "strict": strict, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "multiple_of": multiple_of, + "allow_inf_nan": allow_inf_nan, + "max_digits": max_digits, + "decimal_places": decimal_places, + "min_length": min_length, + "max_length": max_length, + } + + # We only want to pass the args that were provided, otherwise the `Field()` function won't work as expected + provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} + + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined + + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: + default_ = None if default is PydanticUndefined else default + provided_args.update({"default": default_}) + if default is not PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: + default_ = default + provided_args.update({"default": default_}) + json_schema_extra_.orig_default = default_ + + return Field( + **provided_args, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), + ) + + +def OutputField( + # copied from pydantic's Field + default: Any = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + ui_type: Optional[UIType] = None, + ui_hidden: bool = False, +
ui_order: Optional[int] = None, +) -> Any: + """ + Creates an output field for an invocation output. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. + """ + return Field( + default=default, + title=title, + description=description, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), + ) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index f729d60cdd5..16d0f33dda3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,19 +7,16 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - Input, - InputField, InvocationContext, - WithMetadata, invocation, ) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index c3d00bb1330..d4d3d5bea44 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -13,7 +13,8 @@ from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 6bd28896244..c01e0ed0fb2 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,16 +7,13 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) +from
invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b77363ceb86..909c307481e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,6 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( DenoiseMaskField, @@ -35,7 +36,6 @@ ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus @@ -59,12 +59,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index defc61275fe..6ca53011f0b 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -5,10 +5,10 @@ import numpy as np from pydantic import ValidationInfo, field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from invokeai.app.shared.fields import FieldDescriptions -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 14d66f8ef68..399e217dc17 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,20 +5,16 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - MetadataField, - OutputField, - UIType, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.shared.fields import FieldDescriptions from ...version import __version__ diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 
99dcc72999b..c710c9761b0 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -3,17 +3,14 @@ from pydantic import BaseModel, ConfigDict, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.shared.models import FreeUConfig from ...backend.model_management import BaseModelType, ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index b1ee91e1cdf..2e717ac561b 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,17 +4,15 @@ import torch from pydantic import field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField from invokeai.app.invocations.latent import LatentsField -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 759cfde700f..b43d7eaef2c 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -11,9 +11,17 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from tqdm import tqdm +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, + UIType, + WithMetadata, +) from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend import BaseModelType, ModelType, SubModelType @@ -24,13 +32,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIComponent, - UIType, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dccd18f754b..dab9c3dc0f4 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,8 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9d..22f03454a55 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -5,16 +5,12 @@ import torch from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - 
UIComponent, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4778d980771..94b4a217ae7 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,8 @@ from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, UIComponent @invocation( diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 68076fdfeb1..62df5bc8047 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,14 +1,10 @@ -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index 3466206b377..ccbc2f6d924 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,13 +5,11 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) +from .fields import InputField, OutputField, UIComponent from .primitives import StringOutput diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e055d23903f..66ac87c37b8 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,17 +5,14 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.model_management.models.base import BaseModelType diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index e51f891a8db..bdc23ef6edd 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,14 +8,11 @@ BaseInvocation, BaseInvocationOutput, Classification, - Input, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.tiles.tiles import ( diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 5f715c1a7ed..2cab279a9fc 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -14,7 +14,8 @@ from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from 
invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata # TODO: Populate this from disk? # TODO: Use model manager to load? diff --git a/invokeai/app/services/image_files/image_files_base.py b/invokeai/app/services/image_files/image_files_base.py index 27dd67531f4..f4036277b72 100644 --- a/invokeai/app/services/image_files/image_files_base.py +++ b/invokeai/app/services/image_files/image_files_base.py @@ -4,7 +4,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 08448216723..fb687973bad 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -7,7 +7,7 @@ from PIL.Image import Image as PILImageType from send2trash import send2trash -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 727f4977fba..7b7b261ecab 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Optional -from invokeai.app.invocations.metadata import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.shared.pagination import OffsetPaginatedResults from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 74f82e7d84c..5b37913c8fd 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Optional, Union, cast -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index df71dadb5b0..42c42667744 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -3,7 +3,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecord, diff --git 
a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index ff21731a506..adeed738119 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -2,7 +2,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1acf165abac..ba05b050c5b 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,14 +13,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" diff --git a/invokeai/app/shared/fields.py b/invokeai/app/shared/fields.py deleted file mode 100644 index 3e841ffbf22..00000000000 --- a/invokeai/app/shared/fields.py +++ /dev/null @@ -1,67 +0,0 @@ -class FieldDescriptions: - denoising_start = "When to start denoising, expressed a percentage of total steps" - denoising_end = "When to stop denoising, expressed a percentage of total steps" - cfg_scale = "Classifier-Free Guidance scale" - cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" - scheduler = "Scheduler to use during inference" - positive_cond = "Positive conditioning tensor" - negative_cond = "Negative conditioning tensor" - noise = "Noise tensor" - clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" - unet = "UNet (scheduler, LoRAs)" - vae = "VAE" - cond = "Conditioning tensor" - controlnet_model = "ControlNet model to load" - vae_model = "VAE model to load" - lora_model = "LoRA model to load" - main_model = "Main model (UNet, VAE, CLIP) to load" - sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" - sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" - onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" - lora_weight = "The weight at which the LoRA is applied to each model" - compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" - raw_prompt = "Raw prompt text (no parsing)" - sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" - skipped_layers = "Number of layers to skip in text encoder" - seed = "Seed for random number generation" - steps = "Number of steps to run" - width = "Width of output (px)" - height = "Height of output (px)" - control = "ControlNet(s) to apply" - ip_adapter = "IP-Adapter to apply" - t2i_adapter = "T2I-Adapter(s) to apply" - denoised_latents = "Denoised latents tensor" - latents = "Latents tensor" - strength = "Strength of denoising (proportional to steps)" - metadata = "Optional metadata to be saved with the image" - metadata_collection = "Collection of Metadata" - metadata_item_polymorphic = "A single metadata item or collection of metadata items" - metadata_item_label = "Label for this metadata item" - metadata_item_value = "The value for this 
metadata item (may be any type)" - workflow = "Optional workflow to be saved with the image" - interp_mode = "Interpolation mode" - torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" - fp32 = "Whether or not to use full float32 precision" - precision = "Precision to use" - tiled = "Processing using overlapping tiles (reduce memory consumption)" - detect_res = "Pixel resolution for detection" - image_res = "Pixel resolution for output image" - safe_mode = "Whether or not to use safe mode" - scribble_mode = "Whether or not to use scribble mode" - scale_factor = "The factor by which to scale" - blend_alpha = ( - "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." - ) - num_1 = "The first number" - num_2 = "The second number" - mask = "The mask to use for the operation" - board = "The board to save the image to" - image = "The image to process" - tile_size = "Tile size" - inclusive_low = "The inclusive low value" - exclusive_high = "The exclusive high value" - decimal_places = "The number of decimal places to round to" - freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." - freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." diff --git a/invokeai/app/shared/models.py b/invokeai/app/shared/models.py index ed68cb287e3..1a11b480cc5 100644 --- a/invokeai/app/shared/models.py +++ b/invokeai/app/shared/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions class FreeUConfig(BaseModel): diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index bca4e1011f8..e71daad3f3a 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,12 +3,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField From f0e60a4ba22822a429bf1108a05d17f9e73a0ddb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:02:58 +1100 Subject: [PATCH 002/100] feat(nodes): restricts invocation context power Creates a low-power `InvocationContext` with simplified methods and data. See `invocation_context.py` for detailed comments. 
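For illustration, here is roughly what a node body looks like against the wrapped context. This is a minimal sketch, not code from this patch: the invocation, its fields, and the imports are hypothetical, and only the `context.images`, `context.logger`, and DTO usages mirror the interfaces added here.

    from PIL import ImageOps

    from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
    from invokeai.app.invocations.fields import InputField
    from invokeai.app.invocations.primitives import ImageField, ImageOutput
    from invokeai.app.services.shared.invocation_context import InvocationContext


    @invocation("invert_mask_example", title="Invert Mask (Example)", tags=["example"], category="image", version="1.0.0")
    class InvertMaskExampleInvocation(BaseInvocation):
        """Hypothetical node demonstrating the simplified context."""

        mask: ImageField = InputField(description="The mask to invert")

        def invoke(self, context: InvocationContext) -> ImageOutput:
            # The wrapped interfaces expose only safe, node-appropriate operations.
            mask = context.images.get_pil(self.mask.image_name)
            context.logger.info("Inverting mask")
            image_dto = context.images.save(image=ImageOps.invert(mask.convert("L")))
            return ImageOutput(
                image=ImageField(image_name=image_dto.image_name),
                width=image_dto.width,
                height=image_dto.height,
            )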
--- .../app/services/shared/invocation_context.py | 408 ++++++++++++++++++ invokeai/app/util/step_callback.py | 39 +- 2 files changed, 434 insertions(+), 13 deletions(-) create mode 100644 invokeai/app/services/shared/invocation_context.py diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py new file mode 100644 index 00000000000..c0aaac54f87 --- /dev/null +++ b/invokeai/app/services/shared/invocation_context.py @@ -0,0 +1,408 @@ +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Optional + +from PIL.Image import Image +from pydantic import ConfigDict +from torch import Tensor + +from invokeai.app.invocations.compel import ConditioningFieldData +from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import uuid_string +from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation + +""" +The InvocationContext provides access to various services and data about the current invocation. + +We do not provide the invocation services directly, as their methods are both dangerous and +inconvenient to use. + +For example: +- The `images` service allows nodes to delete or unsafely modify existing images. +- The `configuration` service allows nodes to change the app's config at runtime. +- The `events` service allows nodes to emit arbitrary events. + +Wrapping these services provides a simpler and safer interface for nodes to use. + +When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere +with each other. + +Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. +""" + + +@dataclass(frozen=True) +class InvocationContextData: + invocation: "BaseInvocation" + session_id: str + queue_id: str + source_node_id: str + queue_item_id: int + batch_id: str + workflow: Optional[WorkflowWithoutID] = None + + +class LoggerInterface: + def __init__(self, services: InvocationServices) -> None: + def debug(message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + services.logger.debug(message) + + def info(message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + services.logger.info(message) + + def warning(message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + services.logger.warning(message) + + def error(message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. 
+ """ + services.logger.error(message) + + self.debug = debug + self.info = info + self.warning = warning + self.error = error + + +class ImagesInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def save( + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ + from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ + override or provide metadata manually. + """ + + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata + ) + + return services.images.create( + image=image, + is_intermediate=context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=context_data.workflow, + session_id=context_data.session_id, + node_id=context_data.invocation.id, + ) + + def get_pil(image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return services.images.get_pil_image(image_name) + + def get_metadata(image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return services.images.get_metadata(image_name) + + def get_dto(image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return services.images.get_dto(image_name) + + def update( + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. + + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. + + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. + + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. + :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. 
+ """ + if is_intermediate is not None: + services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) + if board_id is None: + services.board_images.remove_image_from_board(image_name) + else: + services.board_images.add_image_to_board(image_name, board_id) + return services.images.get_dto(image_name) + + self.save = save + self.get_pil = get_pil + self.get_metadata = get_metadata + self.get_dto = get_dto + self.update = update + + +class LatentsKind(str, Enum): + IMAGE = "image" + NOISE = "noise" + MASK = "mask" + MASKED_IMAGE = "masked_image" + OTHER = "other" + + +class LatentsInterface: + def __init__( + self, + services: InvocationServices, + context_data: InvocationContextData, + ) -> None: + def save(tensor: Tensor) -> str: + """ + Saves a latents tensor, returning its name. + + :param tensor: The latents tensor to save. + """ + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" + services.latents.save( + name=name, + data=tensor, + ) + return name + + def get(latents_name: str) -> Tensor: + """ + Gets a latents tensor by name. + + :param latents_name: The name of the latents tensor to get. + """ + return services.latents.get(latents_name) + + self.save = save + self.get = get + + +class ConditioningInterface: + def __init__( + self, + services: InvocationServices, + context_data: InvocationContextData, + ) -> None: + def save(conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. + + :param conditioning_data: The conditioning data to save. + """ + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name + + def get(conditioning_name: str) -> Tensor: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. + """ + return services.latents.get(conditioning_name) + + self.save = save + self.get = get + + +class ModelsInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. + + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return services.model_manager.model_exists(model_name, base_model, model_type) + + def load( + model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ + return services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=context_data + ) + + def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. 
+ """ + return services.model_manager.model_info(model_name, base_model, model_type) + + self.exists = exists + self.load = load + self.get_info = get_info + + +class ConfigInterface: + def __init__(self, services: InvocationServices) -> None: + def get() -> InvokeAIAppConfig: + """ + Gets the app's config. + """ + # The config can be changed at runtime. We don't want nodes doing this, so we make a + # frozen copy.. + config = services.configuration.get_config() + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + self.get = get + + +class UtilInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def sd_step_callback( + intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, + ) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each step of the diffusion process. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. + """ + stable_diffusion_step_callback( + context_data=context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=services.queue, + events=services.events, + ) + + self.sd_step_callback = sd_step_callback + + +class InvocationContext: + """ + The invocation context provides access to various services and data about the current invocation. + """ + + def __init__( + self, + images: ImagesInterface, + latents: LatentsInterface, + models: ModelsInterface, + config: ConfigInterface, + logger: LoggerInterface, + data: InvocationContextData, + util: UtilInterface, + conditioning: ConditioningInterface, + ) -> None: + self.images = images + "Provides methods to save, get and update images and their metadata." + self.logger = logger + "Provides access to the app logger." + self.latents = latents + "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + self.conditioning = conditioning + "Provides methods to save and get conditioning data." + self.models = models + "Provides methods to check if a model exists, get a model, and get a model's info." + self.config = config + "Provides access to the app's config." + self.data = data + "Provides data about the current queue item and invocation." + self.util = util + "Provides utility methods." + + +def build_invocation_context( + services: InvocationServices, + context_data: InvocationContextData, +) -> InvocationContext: + """ + Builds the invocation context. This is a wrapper around the invocation services that provides + a more convenient (and less dangerous) interface for nodes to use. + + :param invocation_services: The invocation services to wrap. + :param invocation_context_data: The invocation context data. 
+ """ + + logger = LoggerInterface(services=services) + images = ImagesInterface(services=services, context_data=context_data) + latents = LatentsInterface(services=services, context_data=context_data) + models = ModelsInterface(services=services, context_data=context_data) + config = ConfigInterface(services=services) + util = UtilInterface(services=services, context_data=context_data) + conditioning = ConditioningInterface(services=services, context_data=context_data) + + ctx = InvocationContext( + images=images, + logger=logger, + config=config, + latents=latents, + models=models, + data=context_data, + util=util, + conditioning=conditioning, + ) + + return ctx diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index f166206d528..5cc3caa9ba5 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,12 +1,25 @@ +from typing import Protocol + import torch from PIL import Image +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC +from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL -from ..invocations.baseinvocation import InvocationContext + + +class StepCallback(Protocol): + def __call__( + self, + intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, + ) -> None: + ... def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -25,13 +38,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context: InvocationContext, + context_data: InvocationContextData, intermediate_state: PipelineIntermediateState, - node: dict, - source_node_id: str, base_model: BaseModelType, -): - if context.services.queue.is_canceled(context.graph_execution_state_id): + invocation_queue: InvocationQueueABC, + events: EventServiceBase, +) -> None: + if invocation_queue.is_canceled(context_data.session_id): raise CanceledException # Some schedulers report not only the noisy latents at the current timestep, @@ -108,13 +121,13 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - context.services.events.emit_generator_progress( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - node=node, - source_node_id=source_node_id, + events.emit_generator_progress( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, + node_id=context_data.invocation.id, + source_node_id=context_data.source_node_id, progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), step=intermediate_state.step, order=intermediate_state.order, From 97a6c6eea799a431a927677e797cb2a5c544ba56 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:01 +1100 Subject: [PATCH 003/100] feat: add pyright config I was having issues with mypy bother over- and under-reporting certain 
From 97a6c6eea799a431a927677e797cb2a5c544ba56 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:01 +1100 Subject: [PATCH 003/100] feat: add pyright config I was having issues with mypy both over- and under-reporting certain problems. I've added a pyright config. --- pyproject.toml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7f4b0d77f25..d063f1ad0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,3 +280,19 @@ module = [ "invokeai.frontend.install.model_install", ] #=== End: MyPy + +[tool.pyright] +include = [ + "invokeai/app/invocations/" +] +exclude = [ + "**/node_modules", + "**/__pycache__", + "invokeai/app/invocations/onnx.py", + "invokeai/app/api/routers/models.py", + "invokeai/app/services/invocation_stats/invocation_stats_default.py", + "invokeai/app/services/model_manager/model_manager_base.py", + "invokeai/app/services/model_manager/model_manager_default.py", + "invokeai/app/services/model_records/model_records_sql.py", + "invokeai/app/util/controlnet_utils.py", +] From 7e5ba2795e0bd9d90e82150899eef8b8edc34f8f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:16 +1100 Subject: [PATCH 004/100] feat(nodes): update all invocations to use new invocation context Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them. Supporting minor changes: - Patch bump for all nodes that use the context - Update invocation processor to provide new context - Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node - Minor change to `ModelManagerService` to support the new wrapped context - Finagling of imports to avoid circular dependencies --- invokeai/app/invocations/baseinvocation.py | 54 +- invokeai/app/invocations/collections.py | 8 +- invokeai/app/invocations/compel.py | 105 ++-- .../controlnet_image_processors.py | 56 +- invokeai/app/invocations/cv.py | 32 +- invokeai/app/invocations/facetools.py | 129 ++-- invokeai/app/invocations/fields.py | 57 +- invokeai/app/invocations/image.py | 586 +++++------------- invokeai/app/invocations/infill.py | 123 +--- invokeai/app/invocations/ip_adapter.py | 9 +- invokeai/app/invocations/latent.py | 238 +++---- invokeai/app/invocations/math.py | 22 +- invokeai/app/invocations/metadata.py | 19 +- invokeai/app/invocations/model.py | 29 +- invokeai/app/invocations/noise.py | 25 +- invokeai/app/invocations/onnx.py | 10 +- invokeai/app/invocations/param_easing.py | 44 +- invokeai/app/invocations/primitives.py | 136 ++-- invokeai/app/invocations/prompt.py | 6 +- invokeai/app/invocations/sdxl.py | 13 +- invokeai/app/invocations/strings.py | 11 +- invokeai/app/invocations/t2i_adapter.py | 6 +- invokeai/app/invocations/tiles.py | 38 +- invokeai/app/invocations/upscale.py | 33 +- invokeai/app/services/events/events_base.py | 4 +- .../invocation_processor_default.py | 24 +- .../model_manager/model_manager_base.py | 9 +- .../model_manager/model_manager_common.py | 0 .../model_manager/model_manager_default.py | 44 +- invokeai/app/services/shared/graph.py | 7 +- .../app/services/shared/invocation_context.py | 9 +- invokeai/app/util/step_callback.py | 23 +- 32 files changed, 717 insertions(+), 1192 deletions(-) create mode 100644 invokeai/app/services/model_manager/model_manager_common.py diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 395d5e98707..c4aed1fac5a 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -16,10 +16,16 @@ from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from invokeai.app.invocations.fields
import FieldKind, Input +from invokeai.app.invocations.fields import ( + FieldDescriptions, + FieldKind, + Input, + InputFieldJSONSchemaExtra, + MetadataField, + logger, +) from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string from invokeai.backend.util.logging import InvokeAILogger @@ -219,7 +225,7 @@ def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput: """ Internal invoke method, calls `invoke()` after some prep. Handles optional fields that are required to call `invoke()` and invocation cache. @@ -244,23 +250,23 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: raise MissingInputException(self.model_fields["type"].default, field_name) # skip node cache codepath if it's disabled - if context.services.configuration.node_cache_size == 0: + if services.configuration.node_cache_size == 0: return self.invoke(context) output: BaseInvocationOutput if self.use_cache: - key = context.services.invocation_cache.create_key(self) - cached_value = context.services.invocation_cache.get(key) + key = services.invocation_cache.create_key(self) + cached_value = services.invocation_cache.get(key) if cached_value is None: - context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') output = self.invoke(context) - context.services.invocation_cache.save(key, output) + services.invocation_cache.save(key, output) return output else: - context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') return cached_value else: - context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') + services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) id: str = Field( @@ -513,3 +519,29 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper + + +class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. + """ + + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." 
+ ) + super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index d35a9d79c74..f5709b4ba36 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -27,7 +27,7 @@ def stop_gt_start(cls, v: int, info: ValidationInfo): raise ValueError("stop must be greater than start") return v - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) @@ -45,7 +45,7 @@ class RangeOfSizeInvocation(BaseInvocation): size: int = InputField(default=1, gt=0, description="The number of values") step: int = InputField(default=1, description="The step of the range") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput( collection=list(range(self.start, self.start + (self.step * self.size), self.step)) ) @@ -72,6 +72,6 @@ class RandomRangeInvocation(BaseInvocation): description="The seed for the RNG (omit for random)", ) - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b386aef2cbe..b4496031bc4 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,12 +1,18 @@ -from dataclasses import dataclass -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput +from invokeai.app.invocations.fields import ( + ConditioningFieldData, + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, +) +from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, @@ -20,16 +26,14 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from .model import ClipField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] # unconditioned: Optional[torch.Tensor] @@ -44,7 +48,7 @@ class ConditioningFieldData: title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -61,26 +65,18 @@ class CompelInvocation(BaseInvocation): ) 
@torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - context=context, - ) + def invoke(self, context) -> ConditioningOutput: + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) def _lora_loader(): for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(self.prompt): @@ -89,11 +85,10 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -124,7 +119,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(self.prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: log_tokenization_for_conjunction(conjunction, tokenizer) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -145,34 +140,23 @@ def _lora_loader(): ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) class SDXLPromptInvocationBase: def run_clip_compel( self, - context: InvocationContext, + context: "InvocationContext", clip_field: ClipField, prompt: str, get_pooled: bool, lora_prefix: str, zero_on_empty: bool, ): - tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: @@ -196,14 +180,12 @@ def run_clip_compel( def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in 
extract_ti_triggers_from_prompt(prompt): @@ -212,11 +194,10 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -249,7 +230,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: # TODO: better logging for and syntax log_tokenization_for_conjunction(conjunction, tokenizer) @@ -282,7 +263,7 @@ def _lora_loader(): title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -307,7 +288,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -364,14 +345,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation( @@ -379,7 +355,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: title="SDXL Refiner Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -397,7 +373,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -417,14 +393,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation_output("clip_skip_output") @@ -447,7 +418,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: + def 
invoke(self, context) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 9b652b8eee9..3797722c93e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,18 +25,17 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector +from invokeai.backend.model_management.models.base import BaseModelType -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -121,7 +120,7 @@ def validate_begin_end_step_percent(self) -> "ControlNetInvocation": validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> ControlOutput: + def invoke(self, context) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -145,23 +144,14 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image - def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? 
processed_image = self.run_processor(raw_image) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery - image_dto = context.services.images.create( - image=processed_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.CONTROL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=processed_image) """Builds an ImageOutput and its ImageField""" processed_image_field = ImageField(image_name=image_dto.image_name) @@ -180,7 +170,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Canny Processor", tags=["controlnet", "canny"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -203,7 +193,7 @@ def run_processor(self, image): title="HED (softedge) Processor", tags=["controlnet", "hed", "softedge"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -232,7 +222,7 @@ def run_processor(self, image): title="Lineart Processor", tags=["controlnet", "lineart"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -254,7 +244,7 @@ def run_processor(self, image): title="Lineart Anime Processor", tags=["controlnet", "lineart", "anime"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -277,7 +267,7 @@ def run_processor(self, image): title="Midas Depth Processor", tags=["controlnet", "midas"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -304,7 +294,7 @@ def run_processor(self, image): title="Normal BAE Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -321,7 +311,7 @@ def run_processor(self, image): @invocation( - "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0" + "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1" ) class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -344,7 +334,7 @@ def run_processor(self, image): @invocation( - "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0" + "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1" ) class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -371,7 +361,7 @@ def run_processor(self, image): title="Content Shuffle Processor", tags=["controlnet", "contentshuffle"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle 
processing to image""" @@ -401,7 +391,7 @@ def run_processor(self, image): title="Zoe (Depth) Processor", tags=["controlnet", "zoe", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -417,7 +407,7 @@ def run_processor(self, image): title="Mediapipe Face Processor", tags=["controlnet", "mediapipe", "face"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -440,7 +430,7 @@ def run_processor(self, image): title="Leres (Depth) Processor", tags=["controlnet", "leres", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -469,7 +459,7 @@ def run_processor(self, image): title="Tile Resample Processor", tags=["controlnet", "tile"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -509,7 +499,7 @@ def run_processor(self, img): title="Segment Anything Processor", tags=["controlnet", "segmentanything"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" @@ -551,7 +541,7 @@ def show_anns(self, anns: List[Dict]): title="Color Map Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ColorMapImageProcessorInvocation(ImageProcessorInvocation): """Generates a color map from the provided image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5865338e192..375b18f9c58 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,23 +5,23 @@ import numpy from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata -@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") +@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") class CvInpaintInvocation(BaseInvocation, WithMetadata): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = context.services.images.get_pil_image(self.mask.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + mask = context.images.get_pil(self.mask.image_name) # Convert to cv image/mask # TODO: consider making these utility functions @@ -35,18 +35,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # TODO: consider making a utility function image_inpainted = 
Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - image_dto = context.services.images.create( - image=image_inpainted, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=image_inpainted) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 13f1066ec3e..2c92e28cfe0 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import Optional, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict import cv2 import numpy as np @@ -13,13 +13,16 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory + +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext @invocation_output("face_mask_output") @@ -174,7 +177,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: InvocationContext, + context: "InvocationContext", minimum_confidence: float, x_offset: float, y_offset: float, @@ -273,7 +276,7 @@ def generate_face_box_mask( def extract_face( - context: InvocationContext, + context: "InvocationContext", image: ImageType, face: FaceResultData, padding: int, @@ -304,37 +307,37 @@ def extract_face( # Adjust the crop boundaries to stay within the original image's dimensions if x_min < 0: - context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.") + context.logger.warning("FaceTools --> -X-axis padding reached image edge.") x_max -= x_min x_min = 0 elif x_max > mask.width: - context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.") + context.logger.warning("FaceTools --> +X-axis padding reached image edge.") x_min -= x_max - mask.width x_max = mask.width if y_min < 0: - context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> +Y-axis padding reached image edge.") y_max -= y_min y_min = 0 elif y_max > mask.height: - context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> -Y-axis padding reached image edge.") y_min -= y_max - mask.height y_max = mask.height # Ensure the crop is square and adjust the boundaries if needed if x_max - x_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") diff = crop_size - (x_max - x_min) x_min -= diff 
// 2 x_max += diff - diff // 2 if y_max - y_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") diff = crop_size - (y_max - y_min) y_min -= diff // 2 y_max += diff - diff // 2 - context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") + context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") # Crop the output image to the specified size with the center of the face mesh as the center. mask = mask.crop((x_min, y_min, x_max, y_max)) @@ -354,7 +357,7 @@ def extract_face( def get_faces_list( - context: InvocationContext, + context: "InvocationContext", image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -366,7 +369,7 @@ def get_faces_list( # Generate the face box mask and get the center of the face. if not should_chunk: - context.services.logger.info("FaceTools --> Attempting full image face detection.") + context.logger.info("FaceTools --> Attempting full image face detection.") result = generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -378,7 +381,7 @@ def get_faces_list( draw_mesh=draw_mesh, ) if should_chunk or len(result) == 0: - context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") + context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") width, height = image.size image_chunks = [] x_offsets = [] @@ -397,7 +400,7 @@ def get_faces_list( x_offsets.append(x) y_offsets.append(0) fx += increment - context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}") + context.logger.info(f"FaceTools --> Chunk starting at x = {x}") elif height > width: # Portrait - slice the image vertically fy = 0.0 @@ -409,10 +412,10 @@ def get_faces_list( x_offsets.append(0) y_offsets.append(y) fy += increment - context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}") + context.logger.info(f"FaceTools --> Chunk starting at y = {y}") for idx in range(len(image_chunks)): - context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") + context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") result = result + generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -426,7 +429,7 @@ def get_faces_list( if len(result) == 0: # Give up - context.services.logger.warning( + context.logger.warning( "FaceTools --> No face detected in chunked input image. Passing through original image." ) @@ -435,7 +438,7 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0") +@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1") class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -456,7 +459,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -468,11 +471,11 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr ) if len(all_faces) == 0: - context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.") + context.logger.warning("FaceOff --> No faces detected. Passing through original image.") return None if self.face_id > len(all_faces) - 1: - context.services.logger.warning( + context.logger.warning( f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image." ) return None @@ -483,8 +486,8 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr return face_data - def invoke(self, context: InvocationContext) -> FaceOffOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceOffOutput: + image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) if result is None: @@ -498,24 +501,9 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: x = result["x_min"] y = result["y_min"] - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - mask_dto = context.services.images.create( - image=result_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK) output = FaceOffOutput( image=ImageField(image_name=image_dto.image_name), @@ -529,7 +517,7 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0") +@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1") class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -556,7 +544,7 @@ def validate_comma_separated_ints(cls, v) -> str: raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")') return v - def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: + def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -578,7 +566,7 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu if len(intersected_face_ids) == 0: id_range_str = ",".join([str(id) for id in id_range]) - context.services.logger.warning( + context.logger.warning( f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image." 
) return FaceMaskResult( @@ -613,28 +601,13 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu mask=mask_pil, ) - def invoke(self, context: InvocationContext) -> FaceMaskOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceMaskOutput: + image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) - image_dto = context.services.images.create( - image=result["image"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result["image"]) - mask_dto = context.services.images.create( - image=result["mask"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK) output = FaceMaskOutput( image=ImageField(image_name=image_dto.image_name), @@ -647,7 +620,7 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0" + "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" @@ -661,7 +634,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: + def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -702,22 +675,10 @@ def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageT return image - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0cce8e3c6b5..566babbb6b7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,11 +1,13 @@ +from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -255,6 +257,10 @@ class InputFieldJSONSchemaExtra(BaseModel): class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. 
+ """ + metadata: Optional[MetadataField] = Field( default=None, description=FieldDescriptions.metadata, @@ -498,4 +504,53 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) + + +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") # endregion diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 16d0f33dda3..10ebd97ace3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,30 +7,36 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata -from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + FieldDescriptions, + ImageField, + Input, + InputField, +) +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - InvocationContext, invocation, ) -@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") +@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1") class ShowImageInvocation(BaseInvocation): """Displays a provided image using the OS image viewer, and passes it forward in the pipeline.""" image: ImageField = InputField(description="The image to show") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - if image: - image.show() + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image.show() # TODO: how to handle failure? 
@@ -46,7 +52,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blank Image", tags=["image"], category="image", - version="1.2.0", + version="1.2.1", ) class BlankImageInvocation(BaseInvocation, WithMetadata): """Creates a blank image and forwards it to the pipeline""" @@ -56,25 +62,12 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -82,7 +75,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Crop Image", tags=["image", "crop"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageCropInvocation(BaseInvocation, WithMetadata): """Crops an image to a specified box. The box can be outside of the image.""" @@ -93,28 +86,15 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) image_crop.paste(image, (-self.x, -self.y)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -145,8 +125,8 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions new_width = image.width + self.right + self.left @@ -156,20 +136,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste new image onto input image_crop.paste(image, (self.left, self.top)) - image_dto = context.services.images.create( - image=image_crop, - 
image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -177,7 +146,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Paste Image", tags=["image", "paste"], category="image", - version="1.2.0", + version="1.2.1", ) class ImagePasteInvocation(BaseInvocation, WithMetadata): """Pastes an image into another image.""" @@ -192,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get_pil_image(self.base_image.image_name) - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + base_image = context.images.get_pil(self.base_image.image_name) + image = context.images.get_pil(self.image.image_name) mask = None if self.mask is not None: - mask = context.services.images.get_pil_image(self.mask.image_name) + mask = context.images.get_pil(self.mask.image_name) mask = ImageOps.invert(mask.convert("L")) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -214,22 +183,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: base_w, base_h = base_image.size new_image = new_image.crop((abs(min_x), abs(min_y), abs(min_x) + base_w, abs(min_y) + base_h)) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -237,7 +193,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask from Alpha", tags=["image", "mask"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): """Extracts the alpha channel of an image as a mask.""" @@ -245,29 +201,16 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] if self.invert: image_mask = ImageOps.invert(image_mask) - image_dto = context.services.images.create( - image=image_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - 
workflow=context.workflow, - ) + image_dto = context.images.save(image=image_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -275,7 +218,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Multiply Images", tags=["image", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageMultiplyInvocation(BaseInvocation, WithMetadata): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" @@ -283,28 +226,15 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context: InvocationContext) -> ImageOutput: - image1 = context.services.images.get_pil_image(self.image1.image_name) - image2 = context.services.images.get_pil_image(self.image2.image_name) + def invoke(self, context) -> ImageOutput: + image1 = context.images.get_pil(self.image1.image_name) + image2 = context.images.get_pil(self.image2.image_name) multiply_image = ImageChops.multiply(image1, image2) - image_dto = context.services.images.create( - image=multiply_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=multiply_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_CHANNELS = Literal["A", "R", "G", "B"] @@ -315,7 +245,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Extract Image Channel", tags=["image", "channel"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelInvocation(BaseInvocation, WithMetadata): """Gets a channel from an image.""" @@ -323,27 +253,14 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) - image_dto = context.services.images.create( - image=channel_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=channel_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] @@ -354,7 +271,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Convert Image Mode", tags=["image", "convert"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageConvertInvocation(BaseInvocation, 
WithMetadata): """Converts an image to a different mode.""" @@ -362,27 +279,14 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) - image_dto = context.services.images.create( - image=converted_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=converted_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -390,7 +294,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur Image", tags=["image", "blur"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageBlurInvocation(BaseInvocation, WithMetadata): """Blurs an image""" @@ -400,30 +304,17 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) blur = ( ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) ) blur_image = image.filter(blur) - image_dto = context.services.images.create( - image=blur_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=blur_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -431,7 +322,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Unsharp Mask", tags=["image", "unsharp_mask"], category="image", - version="1.2.0", + version="1.2.1", classification=Classification.Beta, ) class UnsharpMaskInvocation(BaseInvocation, WithMetadata): @@ -447,8 +338,8 @@ def pil_from_array(self, arr): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) mode = image.mode alpha_channel = image.getchannel("A") if mode == "RGBA" else None @@ -466,16 +357,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if alpha_channel is not None: image.putalpha(alpha_channel) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - 
session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) return ImageOutput( image=ImageField(image_name=image_dto.image_name), @@ -509,7 +391,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Image", tags=["image", "resize"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageResizeInvocation(BaseInvocation, WithMetadata): """Resizes an image to specific dimensions""" @@ -519,8 +401,8 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -529,22 +411,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -552,7 +421,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Scale Image", tags=["image", "scale"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageScaleInvocation(BaseInvocation, WithMetadata): """Scales an image by a factor""" @@ -565,8 +434,8 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width * self.scale_factor) @@ -577,22 +446,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -600,7 +456,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Lerp Image", tags=["image", "lerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageLerpInvocation(BaseInvocation, WithMetadata): """Linear interpolation of all pixels of an image""" @@ -609,30 +465,17 @@ class 
ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = image_arr * (self.max - self.min) + self.min lerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=lerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=lerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -640,7 +483,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): """Inverse linear interpolation of all pixels of an image""" @@ -649,30 +492,17 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment] ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=ilerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=ilerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -680,17 +510,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur NSFW Image", tags=["image", "nsfw"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - logger = context.services.logger + logger = context.logger logger.debug("Running NSFW checker") if SafetyChecker.has_nsfw_concept(image): logger.info("A potentially NSFW image has been 
detected. Image will be blurred.") @@ -699,22 +529,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: blurry_image.paste(caution, (0, 0), caution) image = blurry_image - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) def _get_caution_img(self) -> Image.Image: import invokeai.app.assets.images as image_assets @@ -728,7 +545,7 @@ def _get_caution_img(self) -> Image.Image: title="Add Invisible Watermark", tags=["image", "watermark"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageWatermarkInvocation(BaseInvocation, WithMetadata): """Add an invisible watermark to an image""" @@ -736,25 +553,12 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -762,7 +566,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata): """Applies an edge mask to an image""" @@ -775,8 +579,8 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context: InvocationContext) -> ImageOutput: - mask = context.services.images.get_pil_image(self.image.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))) @@ -791,22 +595,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: new_mask = ImageOps.invert(new_mask) - image_dto = context.services.images.create( - image=new_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_mask, image_category=ImageCategory.MASK) - return ImageOutput( - 
image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -814,7 +605,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Combine Masks", tags=["image", "mask", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskCombineInvocation(BaseInvocation, WithMetadata): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" @@ -822,28 +613,15 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context: InvocationContext) -> ImageOutput: - mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") - mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask1 = context.images.get_pil(self.mask1.image_name).convert("L") + mask2 = context.images.get_pil(self.mask2.image_name).convert("L") combined_mask = ImageChops.multiply(mask1, mask2) - image_dto = context.services.images.create( - image=combined_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=combined_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -851,7 +629,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Color Correct", tags=["image", "color"], category="image", - version="1.2.0", + version="1.2.1", ) class ColorCorrectInvocation(BaseInvocation, WithMetadata): """ @@ -864,14 +642,14 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pil_init_mask = None if self.mask is not None: - pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L") + pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") - init_image = context.services.images.get_pil_image(self.reference.image_name) + init_image = context.images.get_pil(self.reference.image_name) - result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + result = context.images.get_pil(self.image.image_name).convert("RGBA") # if init_image is None or init_mask is None: # return result @@ -945,22 +723,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) - image_dto = context.services.images.create( - image=matched_result, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = 
context.images.save(image=matched_result) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -968,7 +733,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Adjust Image Hue", tags=["image", "hue"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): """Adjusts the Hue of an image.""" @@ -976,8 +741,8 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space hsv_image = numpy.array(pil_image.convert("HSV")) @@ -991,24 +756,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to PIL format and to original color mode pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) COLOR_CHANNELS = Literal[ @@ -1072,7 +822,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): """Add or subtract a value from a specific color channel of an image.""" @@ -1081,8 +831,8 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1101,24 +851,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1143,7 +878,7 @@ def 
invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): """Scale a specific color channel of an image.""" @@ -1153,8 +888,8 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1177,24 +912,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - workflow=context.workflow, - metadata=self.metadata, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1202,7 +922,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Save Image", tags=["primitives", "image"], category="primitives", - version="1.2.0", + version="1.2.1", use_cache=False, ) class SaveImageInvocation(BaseInvocation, WithMetadata): @@ -1211,26 +931,12 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - board_id=self.board.board_id if self.board else None, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1238,7 +944,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Linear UI Image Output", tags=["primitives", "image"], category="primitives", - version="1.0.1", + version="1.0.2", use_cache=False, ) class LinearUIOutputInvocation(BaseInvocation, WithMetadata): @@ -1247,19 +953,13 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, 
description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.services.images.get_dto(self.image.image_name) - - if self.board: - context.services.board_images.add_image_to_board(self.board.board_id, self.image.image_name) + def invoke(self, context) -> ImageOutput: + image_dto = context.images.get_dto(self.image.image_name) - if image_dto.is_intermediate != self.is_intermediate: - context.services.images.update( - self.image.image_name, changes=ImageRecordChanges(is_intermediate=self.is_intermediate) - ) - - return ImageOutput( - image=ImageField(image_name=self.image.image_name), - width=image_dto.width, - height=image_dto.height, + image_dto = context.images.update( + image_name=self.image.image_name, + board_id=self.board.board_id if self.board else None, + is_intermediate=self.is_intermediate, ) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index d4d3d5bea44..be51c8312f9 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -6,15 +6,15 @@ import numpy as np from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ColorField, ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InvocationContext, invocation -from .fields import InputField, WithMetadata +from .baseinvocation import BaseInvocation, WithMetadata, invocation +from .fields import InputField from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si -@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class InfillColorInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with a solid color""" @@ -129,33 +129,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return 
ImageOutput.build(image_dto) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") class InfillTileInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with tiles of the image""" @@ -168,32 +155,19 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( - "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0" + "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the PatchMatch algorithm""" @@ -202,8 +176,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -228,77 +202,38 @@ def invoke(self, context: InvocationContext) -> ImageOutput: infilled.paste(image, (0, 0), mask=image.split()[-1]) # image.paste(infilled, (0, 0), mask=image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class LaMaInfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The 
image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class CV2InfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index c01e0ed0fb2..b836be04b58 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,7 +7,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -62,7 +61,7 @@ class IPAdapterOutput(BaseInvocationOutput): ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -93,9 +92,9 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> IPAdapterOutput: + def invoke(self, context) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_manager.model_info( + ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter ) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model @@ -104,7 +103,7 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput: # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) + os.path.join(context.config.get().models_path, ip_adapter_info["path"]) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model = CLIPVisionModelField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 909c307481e..0127a6521e1 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Union import einops import numpy as np @@ -23,21 +23,26 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata +from invokeai.app.invocations.fields import ( + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIType, + WithMetadata, +) from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( - DenoiseMaskField, DenoiseMaskOutput, - ImageField, ImageOutput, - LatentsField, LatentsOutput, - build_latents_output, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo @@ -59,14 +64,15 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) -from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext + if choose_torch_device() == torch.device("mps"): from torch import mps @@ -102,7 +108,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context: InvocationContext) -> SchedulerOutput: + def invoke(self, context) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -111,7 +117,7 @@ def invoke(self, context: InvocationContext) -> SchedulerOutput: title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", - version="1.0.0", + version="1.0.1", ) class 
CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -137,9 +143,9 @@ def prep_mask_tensor(self, mask_image): return mask_tensor @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + def invoke(self, context) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) if image.dim() == 3: image = image.unsqueeze(0) @@ -147,47 +153,37 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: image = None mask = self.prep_mask_tensor( - context.services.images.get_pil_image(self.mask.image_name), + context.images.get_pil(self.mask.image_name), ) if image is not None: - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents" - context.services.latents.save(masked_latents_name, masked_latents) + masked_latents_name = context.latents.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" - context.services.latents.save(mask_name, mask) + mask_name = context.latents.save(tensor=mask) - return DenoiseMaskOutput( - denoise_mask=DenoiseMaskField( - mask_name=mask_name, - masked_latents_name=masked_latents_name, - ), + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=masked_latents_name, ) def get_scheduler( - context: InvocationContext, + context: "InvocationContext", scheduler_info: ModelInfo, scheduler_name: str, seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -216,7 +212,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", - version="1.5.1", + version="1.5.2", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -302,34 +298,18 @@ def ge_one(cls, v): raise ValueError("cfg_scale must be greater than 1") return v - # TODO: pass this an emitter method or something? or a session for dispatching? 
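# The removal below answers the TODO above: rather than passing an emitter or
# session around, per-node progress dispatch is folded into the context API. A
# sketch of the replacement call shape, matching the step_callback added later
# in this same file:
#
#   def step_callback(state: PipelineIntermediateState):
#       context.util.sd_step_callback(state, self.unet.unet.base_model)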
- def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - base_model=base_model, - ) - def get_conditioning_data( self, - context: InvocationContext, + context: "InvocationContext", scheduler, unet, seed, ) -> ConditioningData: - positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -389,7 +369,7 @@ def __init__(self): def prep_control_data( self, - context: InvocationContext, + context: "InvocationContext", control_input: Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -417,17 +397,16 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, - context=context, ) ) # control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -463,7 +442,7 @@ def prep_control_data( def prep_ip_adapter_data( self, - context: InvocationContext, + context: "InvocationContext", ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -485,19 +464,17 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=single_ip_adapter.ip_adapter_model.model_name, model_type=ModelType.IPAdapter, base_model=single_ip_adapter.ip_adapter_model.base_model, - context=context, ) ) - image_encoder_model_info = context.services.model_manager.get_model( + image_encoder_model_info = context.models.load( model_name=single_ip_adapter.image_encoder_model.model_name, model_type=ModelType.CLIPVision, base_model=single_ip_adapter.image_encoder_model.base_model, - context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. 
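# A one-line sketch of the normalization the comment above describes (the hunk
# below implements exactly this, with single_ipa_images coming from
# single_ip_adapter.image):
#
#   if not isinstance(single_ipa_images, list):
#       single_ipa_images = [single_ipa_images]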
@@ -505,7 +482,7 @@ def prep_ip_adapter_data( if not isinstance(single_ipa_images, list): single_ipa_images = [single_ipa_images] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -532,7 +509,7 @@ def prep_ip_adapter_data( def run_t2i_adapters( self, - context: InvocationContext, + context: "InvocationContext", t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -549,13 +526,12 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( + t2i_adapter_model_info = context.models.load( model_name=t2i_adapter_field.t2i_adapter_model.model_name, model_type=ModelType.T2IAdapter, base_model=t2i_adapter_field.t2i_adapter_model.base_model, - context=context, ) - image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) + image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: @@ -642,30 +618,30 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask(self, context: "InvocationContext", latents): if self.denoise_mask is None: return None, None - mask = context.services.latents.get(self.denoise_mask.mask_name) + mask = context.latents.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) else: masked_latents = None return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None if self.noise is not None: - noise = context.services.latents.get(self.noise.latents_name) + noise = context.latents.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.latents.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -691,27 +667,17 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + context.util.sd_step_callback(state, self.unet.unet.base_model) def _lora_loader(): for 
lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), - context=context, - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) with ( ExitStack() as exit_stack, ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), @@ -787,9 +753,8 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, result_latents) - return build_latents_output(latents_name=name, latents=result_latents, seed=seed) + name = context.latents.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @invocation( @@ -797,7 +762,7 @@ def _lora_loader(): title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.2.0", + version="1.2.1", ) class LatentsToImageInvocation(BaseInvocation, WithMetadata): """Generates an image from latents.""" @@ -814,13 +779,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> ImageOutput: + latents = context.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: latents = latents.to(vae.device) @@ -849,7 +811,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae.to(dtype=torch.float16) latents = latents.half() - if self.tiled or context.services.configuration.tiled_decode: + if self.tiled or context.config.get().tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -873,22 +835,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] @@ -899,7 +848,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). 
Provided dimensions are floor-divided by 8.""" @@ -921,8 +870,8 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -940,10 +889,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -951,7 +898,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Scale Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" @@ -964,8 +911,8 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -984,10 +931,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -995,7 +940,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.0.0", + version="1.0.1", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -1055,13 +1000,10 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): return latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1069,10 +1011,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents 
= self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) - name = f"{context.graph_execution_state_id}__{self.id}" latents = latents.to("cpu") - context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=latents, seed=None) + name = context.latents.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @staticmethod @@ -1092,7 +1033,7 @@ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTenso title="Blend Latents", tags=["latents", "blend"], category="latents", - version="1.0.0", + version="1.0.1", ) class BlendLatentsInvocation(BaseInvocation): """Blend two latents using a given alpha. Latents must have same size.""" @@ -1107,9 +1048,9 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.services.latents.get(self.latents_a.latents_name) - latents_b = context.services.latents.get(self.latents_b.latents_name) + def invoke(self, context) -> LatentsOutput: + latents_a = context.latents.get(self.latents_a.latents_name) + latents_b = context.latents.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1163,10 +1104,8 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, blended_latents) - return build_latents_output(latents_name=name, latents=blended_latents) + name = context.latents.save(tensor=blended_latents) + return LatentsOutput.build(latents_name=name, latents=blended_latents) # The Crop Latents node was copied from @skunkworxdark's implementation here: @@ -1176,7 +1115,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): title="Crop Latents", tags=["latents", "crop"], category="latents", - version="1.0.0", + version="1.0.1", ) # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. # Currently, if the class names conflict then 'GET /openapi.json' fails. @@ -1210,8 +1149,8 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. 
This value will be converted to a dimension in latent space.", ) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1220,10 +1159,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, cropped_latents) + name = context.latents.save(tensor=cropped_latents) - return build_latents_output(latents_name=name, latents=cropped_latents) + return LatentsOutput.build(latents_name=name, latents=cropped_latents) @invocation_output("ideal_size_output") diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 6ca53011f0b..d2dbf049816 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") @@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +69,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +88,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, context: InvocationContext) -> 
FloatOutput: + def invoke(self, context) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +110,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +128,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +196,7 @@ def no_unrepresentable_results(cls, v: int, info: ValidationInfo): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +270,7 @@ def no_unrepresentable_results(cls, v: float, info: ValidationInfo): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 399e217dc17..9d74abd8c12 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,15 +5,20 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField -from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + MetadataField, + OutputField, + UIType, +) from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField -from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField from ...version import __version__ @@ -59,7 +64,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context: InvocationContext) -> MetadataItemOutput: + def invoke(self, context) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -76,7 +81,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: if isinstance(self.items, 
MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -95,7 +100,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -213,7 +218,7 @@ class CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c710c9761b0..f81e559e446 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -10,7 +10,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -102,7 +101,7 @@ class LoRAModelField(BaseModel): title="Main Model", tags=["model"], category="model", - version="1.0.0", + version="1.0.1", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: + def invoke(self, context) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput): clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + def invoke(self, context) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + def invoke(self, context) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( 
base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -384,7 +383,7 @@ class VAEModelField(BaseModel): model_config = ConfigDict(protected_namespaces=()) -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context: InvocationContext) -> VAEOutput: + def invoke(self, context) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=model_name, model_type=model_type, @@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context: InvocationContext) -> SeamlessModeOutput: + def invoke(self, context) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context: InvocationContext) -> UNetOutput: + def invoke(self, context) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 2e717ac561b..41641152f04 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,15 +4,13 @@ import torch from pydantic import field_validator -from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField -from invokeai.app.invocations.latent import LatentsField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -67,13 +65,13 @@ class NoiseOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) - -def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): - return NoiseOutput( - noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": + return cls( + noise=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) @invocation( @@ -114,7 +112,7 @@ def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context: InvocationContext) -> 
NoiseOutput: + def invoke(self, context) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, @@ -122,6 +120,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, noise) - return build_noise_output(latents_name=name, latents=noise, seed=self.seed) + name = context.latents.save(tensor=noise) + return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index b43d7eaef2c..3f8e6669ab8 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -37,7 +37,7 @@ invocation_output, ) from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler +from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler from .model import ClipField, ModelInfo, UNetField, VaeField ORT_TO_NP_TYPE = { @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ def ge_one(cls, v): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: + def invoke(self, context) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dab9c3dc0f4..bf59e87d270 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +62,7 @@ class FloatLinearRangeInvocation(BaseInvocation): 
description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -110,7 +110,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: title="Step Param Easing", tags=["step", "easing"], category="step", - version="1.0.0", + version="1.0.1", ) class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" @@ -130,7 +130,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) @@ -149,19 +149,19 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: postlist = list(num_poststeps * [self.post_end_value]) if log_diagnostics: - context.services.logger.debug("start_step: " + str(start_step)) - context.services.logger.debug("end_step: " + str(end_step)) - context.services.logger.debug("num_easing_steps: " + str(num_easing_steps)) - context.services.logger.debug("num_presteps: " + str(num_presteps)) - context.services.logger.debug("num_poststeps: " + str(num_poststeps)) - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("postlist size: " + str(len(postlist))) - context.services.logger.debug("prelist: " + str(prelist)) - context.services.logger.debug("postlist: " + str(postlist)) + context.logger.debug("start_step: " + str(start_step)) + context.logger.debug("end_step: " + str(end_step)) + context.logger.debug("num_easing_steps: " + str(num_easing_steps)) + context.logger.debug("num_presteps: " + str(num_presteps)) + context.logger.debug("num_poststeps: " + str(num_poststeps)) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist: " + str(prelist)) + context.logger.debug("postlist: " + str(postlist)) easing_class = EASING_FUNCTIONS_MAP[self.easing] if log_diagnostics: - context.services.logger.debug("easing class: " + str(easing_class)) + context.logger.debug("easing class: " + str(easing_class)) easing_list = [] if self.mirror: # "expected" mirroring # if number of steps is even, squeeze duration down to (number_of_steps)/2 @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) if log_diagnostics: - context.services.logger.debug("base easing duration: " + str(base_easing_duration)) + context.logger.debug("base easing duration: " + str(base_easing_duration)) even_num_steps = num_easing_steps % 2 == 0 # even number of steps easing_function = easing_class( start=self.start_value, @@ -184,14 +184,14 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: easing_val = easing_function.ease(step_index) base_easing_vals.append(easing_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + 
", easing_val: " + str(easing_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) if even_num_steps: mirror_easing_vals = list(reversed(base_easing_vals)) else: mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) if log_diagnostics: - context.services.logger.debug("base easing vals: " + str(base_easing_vals)) - context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) + context.logger.debug("base easing vals: " + str(base_easing_vals)) + context.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) easing_list = base_easing_vals + mirror_easing_vals # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely @@ -226,12 +226,12 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: step_val = easing_function.ease(step_index) easing_list.append(step_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) if log_diagnostics: - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("easing_list size: " + str(len(easing_list))) - context.services.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("easing_list size: " + str(len(easing_list))) + context.logger.debug("postlist size: " + str(len(postlist))) param_list = prelist + easing_list + postlist diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 22f03454a55..ee04345eed8 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -1,16 +1,26 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional, Tuple +from typing import Optional import torch -from pydantic import BaseModel, Field -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIComponent, +) +from invokeai.app.services.images.images_common import ImageDTO from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -49,7 +59,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context: InvocationContext) -> BooleanOutput: + def invoke(self, context) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -65,7 +75,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: + def invoke(self, context) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -98,7 +108,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -114,7 +124,7 @@ class IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], 
description="The collection of integer values") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -145,7 +155,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=self.value) @@ -161,7 +171,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -192,7 +202,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=self.value) @@ -208,7 +218,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -217,18 +227,6 @@ def invoke(self, context: InvocationContext) -> StringCollectionOutput: # region Image -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - @invocation_output("image_output") class ImageOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" @@ -237,6 +235,14 @@ class ImageOutput(BaseInvocationOutput): width: int = OutputField(description="The width of the image in pixels") height: int = OutputField(description="The height of the image in pixels") + @classmethod + def build(cls, image_dto: ImageDTO) -> "ImageOutput": + return cls( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + @invocation_output("image_collection_output") class ImageCollectionOutput(BaseInvocationOutput): @@ -247,7 +253,7 @@ class ImageCollectionOutput(BaseInvocationOutput): ) -@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0") +@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") class ImageInvocation( BaseInvocation, ): @@ -255,8 +261,8 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) return ImageOutput( image=ImageField(image_name=self.image.image_name), @@ -277,7 +283,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + def invoke(self, context) -> 
ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -286,32 +292,24 @@ def invoke(self, context: InvocationContext) -> ImageCollectionOutput: # region DenoiseMask -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - @invocation_output("denoise_mask_output") class DenoiseMaskOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + @classmethod + def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": + return cls( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), + ) + # endregion # region Latents -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - @invocation_output("latents_output") class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" @@ -322,6 +320,14 @@ class LatentsOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": + return cls( + latents=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + @invocation_output("latents_collection_output") class LatentsCollectionOutput(BaseInvocationOutput): @@ -333,17 +339,17 @@ class LatentsCollectionOutput(BaseInvocationOutput): @invocation( - "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0" + "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1" ) class LatentsInvocation(BaseInvocation): """A latents tensor primitive value""" latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) - return build_latents_output(self.latents.latents_name, latents) + return LatentsOutput.build(self.latents.latents_name, latents) @invocation( @@ -360,35 +366,15 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: + def invoke(self, context) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - - # endregion # region Color -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, 
le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - @invocation_output("color_output") class ColorOutput(BaseInvocationOutput): """Base class for nodes that output a single color""" @@ -411,7 +397,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context: InvocationContext) -> ColorOutput: + def invoke(self, context) -> ColorOutput: return ColorOutput(color=self.color) @@ -420,18 +406,16 @@ def invoke(self, context: InvocationContext) -> ColorOutput: # region Conditioning -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - - @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) + @classmethod + def build(cls, conditioning_name: str) -> "ConditioningOutput": + return cls(conditioning=ConditioningField(conditioning_name=conditioning_name)) + @invocation_output("conditioning_collection_output") class ConditioningCollectionOutput(BaseInvocationOutput): @@ -454,7 +438,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -473,7 +457,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: + def invoke(self, context) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 94b4a217ae7..4f5ef43a568 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +91,7 @@ def promptsFromFile( break return prompts - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 62df5bc8047..75a526cfff6 100644 --- a/invokeai/app/invocations/sdxl.py +++ 
b/invokeai/app/invocations/sdxl.py @@ -4,7 +4,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -30,7 +29,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -39,13 +38,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: + def invoke(self, context) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,7 +115,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" @@ -128,13 +127,13 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: + def invoke(self, context) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index ccbc2f6d924..a4c92d9de56 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,7 +5,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -33,7 +32,7 @@ class StringSplitNegInvocation(BaseInvocation): string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringPosNegOutput: + def invoke(self, context) -> StringPosNegOutput: p_string = "" n_string = "" brackets_depth = 0 @@ -77,7 +76,7 @@ class StringSplitInvocation(BaseInvocation): default="", description="Delimiter to spilt with. 
blank will split on the first whitespace" ) - def invoke(self, context: InvocationContext) -> String2Output: + def invoke(self, context) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -95,7 +94,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -107,7 +106,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -126,7 +125,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 66ac87c37b8..74a098a501c 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,13 +5,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField -from invokeai.app.invocations.primitives import ImageField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.model_management.models.base import BaseModelType @@ -91,7 +89,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> T2IAdapterOutput: + def invoke(self, context) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index bdc23ef6edd..dd34c3dc093 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,13 +8,12 @@ BaseInvocation, BaseInvocationOutput, Classification, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import 
ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -58,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -101,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -131,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -176,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: + def invoke(self, context) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -213,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context: InvocationContext) -> PairTileImageOutput: + def invoke(self, context) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -249,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] @@ -265,7 +264,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # existed in memory at an earlier point in the graph. 
tile_np_images: list[np.ndarray] = [] for image in images: - pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = context.images.get_pil(image.image_name) pil_image = pil_image.convert("RGB") tile_np_images.append(np.array(pil_image)) @@ -288,18 +287,5 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert into a PIL image and save pil_image = Image.fromarray(np_image) - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=pil_image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 2cab279a9fc..ef174809860 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -8,13 +8,13 @@ from PIL import Image from pydantic import ConfigDict -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata # TODO: Populate this from disk? 
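The tiles and upscale diffs around this point all make the same migration: `context.services.images.get_pil_image(...)` becomes `context.images.get_pil(...)`, and the long `context.services.images.create(...)` call (origin, category, node ID, session ID, metadata, workflow) collapses into `context.images.save(...)` plus `ImageOutput.build(...)`. For reference, a minimal sketch of a node written against the new context API, assuming this revision of the codebase (import locations shift during this patch series); the node itself (its ID, title, and grayscale behavior) is hypothetical, chosen only to exercise the calls shown in the diffs:

```py
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput


@invocation("grayscale_example", title="Grayscale (Example)", tags=["image"], category="image", version="1.0.0")
class GrayscaleExampleInvocation(BaseInvocation, WithMetadata):
    """Hypothetical node: converts an image to grayscale via the new context API."""

    image: ImageField = InputField(description="The image to convert")

    def invoke(self, context) -> ImageOutput:
        # New API: load the input directly as a PIL image.
        image = context.images.get_pil(self.image.image_name)
        grayscale = image.convert("L")
        # New API: origin, category, session ID, metadata and workflow are all
        # handled by the context; the node only supplies the image.
        image_dto = context.images.save(image=grayscale)
        return ImageOutput.build(image_dto)
```

The removed lines in the surrounding diffs show what this saves: every node previously rebuilt the same eight-argument `create(...)` call and hand-assembled its own `ImageOutput`.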
@@ -30,7 +30,7 @@ from torch import mps -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") class ESRGANInvocation(BaseInvocation, WithMetadata): """Upscales an image using RealESRGAN.""" @@ -42,9 +42,9 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - models_path = context.services.configuration.models_path + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + models_path = context.config.get().models_path rrdbnet_model = None netscale = None @@ -88,7 +88,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: netscale = 2 else: msg = f"Invalid RealESRGAN model: {self.model_name}" - context.services.logger.error(msg) + context.logger.error(msg) raise ValueError(msg) esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") @@ -111,19 +111,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f33495..ad08ae03956 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -55,7 +55,7 @@ def emit_generator_progress( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - node: dict, + node_id: str, source_node_id: str, progress_image: Optional[ProgressImage], step: int, @@ -70,7 +70,7 @@ def emit_generator_progress( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "node_id": node.get("id"), + "node_id": node_id, "source_node_id": source_node_id, "progress_image": progress_image.model_dump() if progress_image is not None else None, "step": step, diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py index 54342c0da13..d2ebe235e63 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py @@ -5,11 +5,11 @@ from typing import Optional import invokeai.backend.util.logging as logger -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem from invokeai.app.services.invocation_stats.invocation_stats_common import ( GESStatsNotFoundError, ) +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import 
Profiler from ..invoker import Invoker @@ -131,16 +131,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # which handles a few things: # - nodes that require a value, but get it only from a connection # - referencing the invocation cache instead of executing the node - outputs = invocation.invoke_internal( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - queue_batch_id=queue_item.session_queue_batch_id, - workflow=queue_item.workflow, - ) + context_data = InvocationContextData( + invocation=invocation, + session_id=graph_id, + workflow=queue_item.workflow, + source_node_id=source_node_id, + queue_id=queue_item.session_queue_id, + queue_item_id=queue_item.session_queue_item_id, + batch_id=queue_item.session_queue_batch_id, + ) + context = build_invocation_context( + services=self.__invoker.services, + context_data=context_data, ) + outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) # Check queue to see if this is canceled, and skip if so if self.__invoker.services.queue.is_canceled(graph_execution_state.id): diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 4c2fc4c085c..a9b53ae2242 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -5,11 +5,12 @@ from abc import ABC, abstractmethod from logging import Logger from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union from pydantic import Field from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -21,9 +22,6 @@ ) from invokeai.backend.model_management.model_cache import CacheStats -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext - class ModelManagerServiceBase(ABC): """Responsible for managing models on disk and in memory""" @@ -49,8 +47,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """Retrieve the indicated model with name and type. 
submodel can be used to get a part (such as the vae) diff --git a/invokeai/app/services/model_manager/model_manager_common.py b/invokeai/app/services/model_manager/model_manager_common.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cdb3e59a91c..b641dd3f1ed 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -11,6 +11,8 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -30,7 +32,7 @@ from .model_manager_base import ModelManagerServiceBase if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import InvocationContext + pass # simple implementation @@ -86,13 +88,16 @@ def __init__( ) logger.info("Model manager service initialized") + def start(self, invoker: Invoker) -> None: + self._invoker: Optional[Invoker] = invoker + def get_model( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """ Retrieve the indicated model. submodel can be used to get a @@ -100,9 +105,9 @@ def get_model( """ # we can emit model loading events if we are executing with access to the invocation context - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,9 +121,9 @@ def get_model( submodel, ) - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -263,22 +268,25 @@ def commit(self, conf_file: Optional[Path] = None): def _emit_load_event( self, - context: InvocationContext, + context_data: InvocationContextData, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, model_info: Optional[ModelInfo] = None, ): - if context.services.queue.is_canceled(context.graph_execution_state_id): + if self._invoker is None: + return + + if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_completed( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -286,11 +294,11 @@ def _emit_load_event( model_info=model_info, ) else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, 
+ self._invoker.services.events.emit_model_load_started( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index ba05b050c5b..c0699eb96bb 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -202,7 +201,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: + def invoke(self, context) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -228,7 +227,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context: InvocationContext) -> IterateInvocationOutput: + def invoke(self, context) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -255,7 +254,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context: InvocationContext) -> CollectInvocationOutput: + def invoke(self, context) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c0aaac54f87..b68e521c73f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,8 +6,7 @@ from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.compel import ConditioningFieldData -from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -245,13 +244,15 @@ def save(conditioning_data: ConditioningFieldData) -> str: ) return name - def get(conditioning_name: str) -> Tensor: + def get(conditioning_name: str) -> ConditioningFieldData: """ Gets conditioning data by name. :param conditioning_name: The name of the conditioning data to get. """ - return services.latents.get(conditioning_name) + # TODO(sm): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed as returning tensors, so we need to ignore the type here. 
+ return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 5cc3caa9ba5..d83b380d95d 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,25 +1,18 @@ -from typing import Protocol +from typing import TYPE_CHECKING import torch from PIL import Image -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage -from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC -from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL - -class StepCallback(Protocol): - def __call__( - self, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - ... +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC + from invokeai.app.services.shared.invocation_context import InvocationContextData def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -38,11 +31,11 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context_data: InvocationContextData, + context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - invocation_queue: InvocationQueueABC, - events: EventServiceBase, + invocation_queue: "InvocationQueueABC", + events: "EventServiceBase", ) -> None: if invocation_queue.is_canceled(context_data.session_id): raise CanceledException From 4aa7bee4b9c4235e4118092be69aa5bf88237722 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:27 +1100 Subject: [PATCH 005/100] docs: update INVOCATIONS.md --- docs/contributing/INVOCATIONS.md | 97 ++++++++++++-------------------- 1 file changed, 36 insertions(+), 61 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 124589f44ce..5d9a3690bad 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -9,11 +9,15 @@ complex functionality. ## Invocations Directory -InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes. +InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These +can be used as examples to create your own nodes. -New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup. +New nodes should be added to a subfolder in the `nodes` directory found at the root +level of the InvokeAI installation location. Nodes added to this folder will be +able to be used upon application startup.
+ +Example `nodes` subfolder structure: -Example `nodes` subfolder structure: ```py ├── __init__.py # Invoke-managed custom node loader │ @@ -30,14 +34,14 @@ Example `nodes` subfolder structure: └── fancy_node.py ``` -Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded. - See the README in the nodes folder for more examples: +Each node folder must have an `__init__.py` file that imports its nodes. Only +nodes imported in the `__init__.py` file are loaded. See the README in the nodes +folder for more examples: ```py from .cool_node import CoolInvocation ``` - ## Creating A New Invocation In order to understand the process of creating a new Invocation, let us actually @@ -131,7 +135,6 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") @@ -167,12 +170,11 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext): + def invoke(self, context): pass ``` @@ -197,12 +199,11 @@ from invokeai.app.invocations.image import ImageOutput class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pass ``` @@ -228,31 +229,18 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext) -> ImageOutput: - # Load the image using InvokeAI's predefined Image Service. Returns the PIL image. - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + # Load the input image as a PIL image + image = context.images.get_pil(self.image.image_name) - # Resizing the image + # Resize the image resized_image = image.resize((self.width, self.height)) - # Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image. 
- output_image = context.services.images.create( - image=resized_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) - - # Returning the Image - return ImageOutput( - image=ImageField( - image_name=output_image.image_name, - ), - width=output_image.width, - height=output_image.height, - ) + # Save the image + image_dto = context.images.save(image=resized_image) + + # Return an ImageOutput + return ImageOutput.build(image_dto) ``` **Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a @@ -343,27 +331,25 @@ class ImageColorStringOutput(BaseInvocationOutput): That's all there is to it. - +Custom fields only support connection inputs in the Workflow Editor. From ac2eb16a658e4e119a6d17fecaaa899ea13eff0f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:38 +1100 Subject: [PATCH 006/100] tests: fix tests for new invocation context --- tests/aa_nodes/test_graph_execution_state.py | 8 +------- tests/aa_nodes/test_nodes.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598f..9cc30e43e11 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -21,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - queue_batch_id="1", - queue_item_id=1, - queue_id=DEFAULT_QUEUE_ID, - services=services, - graph_execution_state_id="1", - workflow=None, + conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None ) ) g.complete(n.id, o) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index e71daad3f3a..559457c0e11 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,7 +3,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: + def invoke(self, context) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def 
invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: + def invoke(self, context) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -97,7 +96,7 @@ def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value) From 8baf3f78a29f5baf94753bfd11b5ddd19e7b0f63 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:05:15 +1100 Subject: [PATCH 007/100] feat(nodes): tidy `invocation_context.py`, improve comments --- .../app/services/shared/invocation_context.py | 115 ++++++++++++------ 1 file changed, 80 insertions(+), 35 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b68e521c73f..7961c011aff 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Optional from PIL.Image import Image @@ -37,6 +36,9 @@ When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere with each other. +Many of the wrappers have the same signature as the methods they wrap. This allows us to write +user-facing docstrings and not need to go and update the internal services to match. + Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. 
""" @@ -44,12 +46,19 @@ @dataclass(frozen=True) class InvocationContextData: invocation: "BaseInvocation" + """The invocation that is being executed.""" session_id: str + """The session that is being executed.""" queue_id: str + """The queue in which the session is being executed.""" source_node_id: str + """The ID of the node from which the currently executing invocation was prepared.""" queue_item_id: int + """The ID of the queue item that is being executed.""" batch_id: str + """The ID of the batch that is being executed.""" workflow: Optional[WorkflowWithoutID] = None + """The workflow associated with this queue item, if any.""" class LoggerInterface: @@ -103,14 +112,15 @@ def save( """ Saves an image, returning its DTO. - If the current queue item has a workflow, it is automatically saved with the image. + If the current queue item has a workflow or metadata, it is automatically saved with the image. :param image: The image to save, as a PIL image. :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ - from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ - override or provide metadata manually. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** """ # If the invocation inherits metadata, use that. Else, use the metadata passed in. @@ -186,14 +196,6 @@ def update( self.update = update -class LatentsKind(str, Enum): - IMAGE = "image" - NOISE = "noise" - MASK = "mask" - MASKED_IMAGE = "masked_image" - OTHER = "other" - - class LatentsInterface: def __init__( self, @@ -206,6 +208,22 @@ def save(tensor: Tensor) -> str: :param tensor: The latents tensor to save. """ + + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. + + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" services.latents.save( name=name, @@ -231,12 +249,21 @@ def __init__( services: InvocationServices, context_data: InvocationContextData, ) -> None: + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. 
:param conditioning_data: The conditioning data to save. """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" services.latents.save( name=name, @@ -250,9 +277,8 @@ def get(conditioning_name: str) -> ConditioningFieldData: :param conditioning_name: The name of the conditioning data to get. """ - # TODO(sm): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed as returning tensors, so we need to ignore the type here. - return services.latents.get(conditioning_name) # type: ignore [return-value] + + return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get @@ -281,6 +307,17 @@ def load( :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + return services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=context_data ) @@ -306,8 +343,11 @@ def get() -> InvokeAIAppConfig: """ Gets the app's config. """ - # The config can be changed at runtime. We don't want nodes doing this, so we make a - # frozen copy.. + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + config = services.configuration.get_config() frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config @@ -330,6 +370,12 @@ def sd_step_callback( :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. """ + + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. + stable_diffusion_step_callback( context_data=context_data, intermediate_state=intermediate_state, @@ -343,36 +389,36 @@ def sd_step_callback( class InvocationContext: """ - The invocation context provides access to various services and data about the current invocation. + The `InvocationContext` provides access to various services and data for the current invocation. """ def __init__( self, images: ImagesInterface, latents: LatentsInterface, + conditioning: ConditioningInterface, models: ModelsInterface, - config: ConfigInterface, logger: LoggerInterface, - data: InvocationContextData, + config: ConfigInterface, util: UtilInterface, - conditioning: ConditioningInterface, + data: InvocationContextData, ) -> None: self.images = images - "Provides methods to save, get and update images and their metadata." - self.logger = logger - "Provides access to the app logger." 
+ """Provides methods to save, get and update images and their metadata.""" self.latents = latents - "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning - "Provides methods to save and get conditioning data." + """Provides methods to save and get conditioning data.""" self.models = models - "Provides methods to check if a model exists, get a model, and get a model's info." + """Provides methods to check if a model exists, get a model, and get a model's info.""" + self.logger = logger + """Provides access to the app logger.""" self.config = config - "Provides access to the app's config." - self.data = data - "Provides data about the current queue item and invocation." + """Provides access to the app's config.""" self.util = util - "Provides utility methods." + """Provides utility methods.""" + self.data = data + """Provides data about the current queue item and invocation.""" def build_invocation_context( @@ -380,8 +426,7 @@ def build_invocation_context( context_data: InvocationContextData, ) -> InvocationContext: """ - Builds the invocation context. This is a wrapper around the invocation services that provides - a more convenient (and less dangerous) interface for nodes to use. + Builds the invocation context for a specific invocation execution. :param invocation_services: The invocation services to wrap. :param invocation_context_data: The invocation context data. From 183c9c4799baa0f4c530b7190752602828ff3cb4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:34:56 +1100 Subject: [PATCH 008/100] chore: ruff --- invokeai/app/invocations/baseinvocation.py | 1 - invokeai/app/invocations/onnx.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index c4aed1fac5a..df0596c9a15 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -22,7 +22,6 @@ Input, InputFieldJSONSchemaExtra, MetadataField, - logger, ) from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.shared.invocation_context import InvocationContext diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 3f8e6669ab8..a1e318a3802 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -318,7 +318,7 @@ def dispatch_progress( name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) + # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) # Latent to image From 5c7ed24aab1bff551aa8dde8ff9b8671b4e1e3f5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 20:16:51 +1100 Subject: [PATCH 009/100] feat(nodes): restore previous invocation context methods with deprecation warnings --- .../app/services/shared/invocation_context.py | 117 +++++++++++++++++- pyproject.toml | 1 + 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 7961c011aff..023274d49fa 100644 --- 
a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional +from deprecated import deprecated from PIL.Image import Image from pydantic import ConfigDict from torch import Tensor @@ -365,7 +366,7 @@ def sd_step_callback( The step callback emits a progress event with the current step, the total number of steps, a preview image, and some other internal metadata. - This should be called after each step of the diffusion process. + This should be called after each denoising step. :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. @@ -387,6 +388,30 @@ def sd_step_callback( self.sd_step_callback = sd_step_callback +deprecation_version = "3.7.0" +removed_version = "3.8.0" + + +def get_deprecation_reason(property_name: str, alternative: Optional[str] = None) -> str: + msg = f"{property_name} is deprecated as of v{deprecation_version}. It will be removed in v{removed_version}." + if alternative is not None: + msg += f" Use {alternative} instead." + msg += " See PLACEHOLDER_URL for details." + return msg + + +# Deprecation docstrings template. I don't think we can implement these programmatically with +# __doc__ because the IDE won't see them. + +""" +**DEPRECATED as of v3.7.0** + +PROPERTY_NAME will be removed in v3.8.0. Use ALTERNATIVE instead. See PLACEHOLDER_URL for details. + +OG_DOCSTRING +""" + + class InvocationContext: """ The `InvocationContext` provides access to various services and data for the current invocation. @@ -402,6 +427,7 @@ def __init__( config: ConfigInterface, util: UtilInterface, data: InvocationContextData, + services: InvocationServices, ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" @@ -419,6 +445,94 @@ def __init__( """Provides utility methods.""" self.data = data """Provides data about the current queue item and invocation.""" + self.__services = services + + @property + @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) + def services(self) -> InvocationServices: + """ + **DEPRECATED as of v3.7.0** + + `context.services` will be removed in v3.8.0. See PLACEHOLDER_URL for details. + + The invocation services. + """ + return self.__services + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.graph_execution_state_id`", "`context.data.session_id`"), + ) + def graph_execution_state_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.graph_execution_state_id` will be removed in v3.8.0. Use `context.data.session_id` instead. See PLACEHOLDER_URL for details. + + The ID of the session (aka graph execution state). + """ + return self.data.session_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_id`", "`context.data.queue_id`"), + ) + def queue_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_id` will be removed in v3.8.0. Use `context.data.queue_id` instead. See PLACEHOLDER_URL for details. + + The ID of the queue.
+ """ + return self.data.queue_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_item_id`", "`context.data.queue_item_id`"), + ) + def queue_item_id(self) -> int: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_item_id` will be removed in v3.8.0. Use `context.data.queue_item_id` instead. See PLACEHOLDER_URL for details. + + The ID of the queue item. + """ + return self.data.queue_item_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_batch_id`", "`context.data.batch_id`"), + ) + def queue_batch_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_batch_id` will be removed in v3.8.0. Use `context.data.batch_id` instead. See PLACEHOLDER_URL for details. + + The ID of the batch. + """ + return self.data.batch_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.workflow`", "`context.data.workflow`"), + ) + def workflow(self) -> Optional[WorkflowWithoutID]: + """ + **DEPRECATED as of v3.7.0** + + `context.workflow` will be removed in v3.8.0. Use `context.data.workflow` instead. See PLACEHOLDER_URL for details. + + The workflow associated with this queue item, if any. + """ + return self.data.workflow def build_invocation_context( @@ -449,6 +563,7 @@ def build_invocation_context( data=context_data, util=util, conditioning=conditioning, + services=services, ) return ctx diff --git a/pyproject.toml b/pyproject.toml index d063f1ad0ee..8d25ed20910 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "albumentations", "click", "datasets", + "Deprecated", "dnspython~=2.4.0", "dynamicprompts", "easing-functions", From 3ceee2b2b248bba93a49799e2dce4ac643dcf199 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:37:05 +1100 Subject: [PATCH 010/100] tests: fix missing arg for InvocationContext --- tests/aa_nodes/test_graph_execution_state.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 9cc30e43e11..3577a78ae2c 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -85,7 +85,15 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None + conditioning=None, + config=None, + data=None, + images=None, + latents=None, + logger=None, + models=None, + util=None, + services=None, ) ) g.complete(n.id, o) From 3de43907113c316f48fd96bcd9afd773675a4a00 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:41:25 +1100 Subject: [PATCH 011/100] feat(nodes): move `ConditioningFieldData` to `conditioning_data.py` --- invokeai/app/invocations/compel.py | 2 +- invokeai/app/invocations/fields.py | 9 +-------- invokeai/app/services/shared/invocation_context.py | 3 ++- .../stable_diffusion/diffusion/conditioning_data.py | 5 +++++ 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b4496031bc4..94caf4128d2 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py 
@@ -5,7 +5,6 @@ from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from invokeai.app.invocations.fields import ( - ConditioningFieldData, FieldDescriptions, Input, InputField, @@ -15,6 +14,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, + ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 566babbb6b7..8879f760770 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,13 +1,11 @@ -from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -544,11 +542,6 @@ def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] - - class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 023274d49fa..3cf3952de0e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,7 +6,7 @@ from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata +from invokeai.app.invocations.fields import MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -17,6 +17,7 @@ from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 3e38f9f78d5..0676555f7a9 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -32,6 +32,11 @@ def to(self, device, dtype=None): return self +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): pooled_embeds: torch.Tensor From a7e23af9c607fd2abb43891e60d5d9bfbc4d5713 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:48:33 +1100 Subject: [PATCH 012/100] 
feat(nodes): create invocation_api.py This is the public API for invocations. Everything a custom node might need should be re-exported from this file. --- .../controlnet_image_processors.py | 3 +- invokeai/app/invocations/facetools.py | 3 +- invokeai/app/invocations/image.py | 2 +- invokeai/app/invocations/infill.py | 4 +- invokeai/app/invocations/tiles.py | 3 +- invokeai/invocation_api/__init__.py | 109 ++++++++++++++++++ pyproject.toml | 1 + 7 files changed, 116 insertions(+), 9 deletions(-) create mode 100644 invokeai/invocation_api/__init__.py diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 3797722c93e..e993ceffde5 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,8 +25,7 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.baseinvocation import WithMetadata -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.image_util.depth_anything import DepthAnythingDetector diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 2c92e28cfe0..dad63089816 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,11 +13,10 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 10ebd97ace3..3b8b0b4b80b 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,7 +7,6 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.baseinvocation import WithMetadata from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -15,6 +14,7 @@ ImageField, Input, InputField, + WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index be51c8312f9..159bdb5f7ad 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -13,8 +13,8 @@ from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, WithMetadata, invocation -from .fields import InputField +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index dd34c3dc093..0b4c472696b 100644 
--- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,11 +8,10 @@ BaseInvocation, BaseInvocationOutput, Classification, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py new file mode 100644 index 00000000000..e867ec3cc4e --- /dev/null +++ b/invokeai/invocation_api/__init__.py @@ -0,0 +1,109 @@ +""" +This file re-exports all the public API for invocations. This is the only file that should be imported by custom nodes. + +TODO(psyche): Do we want to dogfood this? +""" + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + FieldKind, + ImageField, + Input, + InputField, + LatentsField, + MetadataField, + OutputField, + UIComponent, + UIType, + WithMetadata, + WithWorkflow, +) +from invokeai.app.invocations.primitives import ( + BooleanCollectionOutput, + BooleanOutput, + ColorCollectionOutput, + ColorOutput, + ConditioningCollectionOutput, + ConditioningOutput, + DenoiseMaskOutput, + FloatCollectionOutput, + FloatOutput, + ImageCollectionOutput, + ImageOutput, + IntegerCollectionOutput, + IntegerOutput, + LatentsCollectionOutput, + LatentsOutput, + StringCollectionOutput, + StringOutput, +) +from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, + ConditioningFieldData, + ExtraConditioningInfo, + SDXLConditioningInfo, +) + +__all__ = [ + # invokeai.app.invocations.baseinvocation + "BaseInvocation", + "BaseInvocationOutput", + "invocation", + "invocation_output", + # invokeai.app.services.shared.invocation_context + "InvocationContext", + # invokeai.app.invocations.fields + "BoardField", + "ColorField", + "ConditioningField", + "DenoiseMaskField", + "FieldDescriptions", + "FieldKind", + "ImageField", + "Input", + "InputField", + "LatentsField", + "MetadataField", + "OutputField", + "UIComponent", + "UIType", + "WithMetadata", + "WithWorkflow", + # invokeai.app.invocations.primitives + "BooleanCollectionOutput", + "BooleanOutput", + "ColorCollectionOutput", + "ColorOutput", + "ConditioningCollectionOutput", + "ConditioningOutput", + "DenoiseMaskOutput", + "FloatCollectionOutput", + "FloatOutput", + "ImageCollectionOutput", + "ImageOutput", + "IntegerCollectionOutput", + "IntegerOutput", + "LatentsCollectionOutput", + "LatentsOutput", + "StringCollectionOutput", + "StringOutput", + # invokeai.app.services.image_records.image_records_common + "ImageCategory", + # invokeai.backend.stable_diffusion.diffusion.conditioning_data + "BasicConditioningInfo", + "ConditioningFieldData", + "ExtraConditioningInfo", + "SDXLConditioningInfo", +] diff --git a/pyproject.toml b/pyproject.toml index 8d25ed20910..69958064c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ version = { attr = "invokeai.version.__version__" } 
"invokeai.frontend.web.static*", "invokeai.configs*", "invokeai.app*", + "invokeai.invocation_api*", ] [tool.setuptools.package-data] From cc295a9f0a5e36fec2ddee522e5b990555cf803f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 16 Jan 2024 19:19:49 +1100 Subject: [PATCH 013/100] feat: tweak pyright config --- pyproject.toml | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69958064c6d..8b28375e291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -284,17 +284,36 @@ module = [ #=== End: MyPy [tool.pyright] -include = [ - "invokeai/app/invocations/" -] -exclude = [ - "**/node_modules", - "**/__pycache__", - "invokeai/app/invocations/onnx.py", - "invokeai/app/api/routers/models.py", - "invokeai/app/services/invocation_stats/invocation_stats_default.py", - "invokeai/app/services/model_manager/model_manager_base.py", - "invokeai/app/services/model_manager/model_manager_default.py", - "invokeai/app/services/model_records/model_records_sql.py", - "invokeai/app/util/controlnet_utils.py", -] +# Start from strict mode +typeCheckingMode = "strict" +# This errors whenever an import is missing a type stub file - way too noisy +reportMissingTypeStubs = "none" +# These are the rest of the rules enabled by strict mode - enable them @ warning +reportConstantRedefinition = "warning" +reportDeprecated = "warning" +reportDuplicateImport = "warning" +reportIncompleteStub = "warning" +reportInconsistentConstructor = "warning" +reportInvalidStubStatement = "warning" +reportMatchNotExhaustive = "warning" +reportMissingParameterType = "warning" +reportMissingTypeArgument = "warning" +reportPrivateUsage = "warning" +reportTypeCommentUsage = "warning" +reportUnknownArgumentType = "warning" +reportUnknownLambdaType = "warning" +reportUnknownMemberType = "warning" +reportUnknownParameterType = "warning" +reportUnknownVariableType = "warning" +reportUnnecessaryCast = "warning" +reportUnnecessaryComparison = "warning" +reportUnnecessaryContains = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnusedClass = "warning" +reportUnusedImport = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "warning" +reportUntypedBaseClass = "warning" +reportUntypedClassDecorator = "warning" +reportUntypedFunctionDecorator = "warning" +reportUntypedNamedTuple = "warning" From ae421fb4ab3527e40b261eab12f87750db8d08a7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:02:38 +1100 Subject: [PATCH 014/100] feat(nodes): do not freeze InvocationContextData, prevents it from being subclassesd --- invokeai/app/services/shared/invocation_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 3cf3952de0e..a849d6b17a2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -45,7 +45,7 @@ """ -@dataclass(frozen=True) +@dataclass class InvocationContextData: invocation: "BaseInvocation" """The invocation that is being executed.""" From 483bdbcb9fbc612ee0d855bbc3a6023e4535e988 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:16:35 +1100 Subject: [PATCH 015/100] fix(nodes): restore type annotations for `InvocationContext` 
--- docs/contributing/INVOCATIONS.md | 6 +-- invokeai/app/invocations/collections.py | 7 +-- invokeai/app/invocations/compel.py | 18 +++---- .../controlnet_image_processors.py | 5 +- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/facetools.py | 24 ++++----- invokeai/app/invocations/image.py | 51 ++++++++++--------- invokeai/app/invocations/infill.py | 11 ++-- invokeai/app/invocations/ip_adapter.py | 3 +- invokeai/app/invocations/latent.py | 18 +++---- invokeai/app/invocations/math.py | 21 ++++---- invokeai/app/invocations/metadata.py | 9 ++-- invokeai/app/invocations/model.py | 13 ++--- invokeai/app/invocations/noise.py | 3 +- invokeai/app/invocations/onnx.py | 8 +-- invokeai/app/invocations/param_easing.py | 5 +- invokeai/app/invocations/primitives.py | 31 +++++------ invokeai/app/invocations/prompt.py | 5 +- invokeai/app/invocations/sdxl.py | 5 +- invokeai/app/invocations/strings.py | 12 +++-- invokeai/app/invocations/t2i_adapter.py | 3 +- invokeai/app/invocations/tiles.py | 13 ++--- invokeai/app/invocations/upscale.py | 3 +- invokeai/app/services/shared/graph.py | 7 +-- tests/aa_nodes/test_nodes.py | 17 ++++--- 25 files changed, 158 insertions(+), 143 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 5d9a3690bad..ce1ee9e808a 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -174,7 +174,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context): + def invoke(self, context: InvocationContext): pass ``` @@ -203,7 +203,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pass ``` @@ -229,7 +229,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: # Load the input image as a PIL image image = context.images.get_pil(self.image.image_name) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index f5709b4ba36..e02291980f9 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,6 +5,7 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from .baseinvocation import BaseInvocation, invocation @@ -27,7 +28,7 @@ def stop_gt_start(cls, v: int, info: ValidationInfo): raise ValueError("stop must be greater than start") return v - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) @@ -45,7 +46,7 @@ class RangeOfSizeInvocation(BaseInvocation): size: int = InputField(default=1, gt=0, description="The 
number of values") step: int = InputField(default=1, description="The step of the range") - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput( collection=list(range(self.start, self.start + (self.step * self.size), self.step)) ) @@ -72,6 +73,6 @@ class RandomRangeInvocation(BaseInvocation): description="The seed for the RNG (omit for random)", ) - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 94caf4128d2..978c6dcb17f 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import torch from compel import Compel, ReturnedEmbeddingsType @@ -12,6 +12,7 @@ UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -31,10 +32,7 @@ ) from .model import ClipField -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext - - # unconditioned: Optional[torch.Tensor] +# unconditioned: Optional[torch.Tensor] # class ConditioningAlgo(str, Enum): @@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) @@ -148,7 +146,7 @@ def _lora_loader(): class SDXLPromptInvocationBase: def run_clip_compel( self, - context: "InvocationContext", + context: InvocationContext, clip_field: ClipField, prompt: str, get_pooled: bool, @@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context) -> ClipSkipInvocationOutput: + def invoke(self, 
context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e993ceffde5..f8bdf14117c 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -28,6 +28,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector from invokeai.backend.model_management.models.base import BaseModelType @@ -119,7 +120,7 @@ def validate_begin_end_step_percent(self) -> "ControlNetInvocation": validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> ControlOutput: + def invoke(self, context: InvocationContext) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -143,7 +144,7 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 375b18f9c58..1ebabf5e064 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,6 +7,7 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata @@ -19,7 +20,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) mask = context.images.get_pil(self.mask.image_name) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index dad63089816..a1702d6517c 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import Optional, TypedDict import cv2 import numpy as np @@ -19,9 +19,7 @@ from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory - -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.shared.invocation_context import 
InvocationContext @invocation_output("face_mask_output") @@ -176,7 +174,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: "InvocationContext", + context: InvocationContext, minimum_confidence: float, x_offset: float, y_offset: float, @@ -275,7 +273,7 @@ def generate_face_box_mask( def extract_face( - context: "InvocationContext", + context: InvocationContext, image: ImageType, face: FaceResultData, padding: int, @@ -356,7 +354,7 @@ def extract_face( def get_faces_list( - context: "InvocationContext", + context: InvocationContext, image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -458,7 +456,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -485,7 +483,7 @@ def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[Ex return face_data - def invoke(self, context) -> FaceOffOutput: + def invoke(self, context: InvocationContext) -> FaceOffOutput: image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) @@ -543,7 +541,7 @@ def validate_comma_separated_ints(cls, v) -> str: raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")') return v - def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: + def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -600,7 +598,7 @@ def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskRe mask=mask_pil, ) - def invoke(self, context) -> FaceMaskOutput: + def invoke(self, context: InvocationContext) -> FaceMaskOutput: image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) @@ -633,7 +631,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: + def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -674,7 +672,7 @@ def faceidentifier(self, context: "InvocationContext", image: ImageType) -> Imag return image - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 3b8b0b4b80b..7b74e4d96d4 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -18,6 +18,7 @@ ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker @@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation): image: ImageField = InputField(description="The image to show") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image.show() @@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) image_dto = context.images.save(image=image) @@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) @@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions @@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.images.get_pil(self.base_image.image_name) image = context.images.get_pil(self.image.image_name) mask = None @@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = 
InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] @@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.images.get_pil(self.image1.image_name) image2 = context.images.get_pil(self.image2.image_name) @@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) @@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) @@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) blur = ( @@ -338,7 +339,7 @@ def pil_from_array(self, arr): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) mode = image.mode @@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: 
InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 @@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) @@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) logger = context.logger @@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) image_dto = context.images.save(image=new_image) @@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) @@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask1 = context.images.get_pil(self.mask1.image_name).convert("L") mask2 = context.images.get_pil(self.mask2.image_name).convert("L") @@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_init_mask = None if self.mask is not None: pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") @@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space @@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context) -> ImageOutput: + 
def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple @@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple @@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) @@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image_dto = context.images.get_dto(self.image.image_name) image_dto = context.images.update( diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 159bdb5f7ad..b007edd9e42 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -8,6 +8,7 @@ from invokeai.app.invocations.fields import ColorField, ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA @@ -129,7 +130,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) @@ -155,7 +156,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) @@ -176,7 +177,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -213,7 +214,7 @@ class 
LaMaInfillInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to infill") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) @@ -229,7 +230,7 @@ class CV2InfillInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to infill") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index b836be04b58..845fcfa2848 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -13,6 +13,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id @@ -92,7 +93,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> IPAdapterOutput: + def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 0127a6521e1..2cc84f80a73 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import TYPE_CHECKING, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union import einops import numpy as np @@ -42,6 +42,7 @@ LatentsOutput, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings @@ -70,9 +71,6 @@ from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext - if choose_torch_device() == torch.device("mps"): from torch import mps @@ -177,7 +175,7 @@ def invoke(self, context) -> DenoiseMaskOutput: def get_scheduler( - context: "InvocationContext", + context: InvocationContext, scheduler_info: ModelInfo, scheduler_name: str, seed: int, @@ -300,7 +298,7 @@ def ge_one(cls, v): def get_conditioning_data( self, - context: "InvocationContext", + context: InvocationContext, scheduler, unet, seed, @@ -369,7 +367,7 @@ def __init__(self): def prep_control_data( self, - context: "InvocationContext", + context: InvocationContext, control_input: 
Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -442,7 +440,7 @@ def prep_control_data( def prep_ip_adapter_data( self, - context: "InvocationContext", + context: InvocationContext, ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -509,7 +507,7 @@ def prep_ip_adapter_data( def run_t2i_adapters( self, - context: "InvocationContext", + context: InvocationContext, t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -618,7 +616,7 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context: "InvocationContext", latents): + def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index d2dbf049816..83a092be69e 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -7,6 +7,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation @@ -18,7 +19,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +30,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +41,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +52,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +70,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +89,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, context) -> FloatOutput: + 
def invoke(self, context: InvocationContext) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +111,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +129,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +197,7 @@ def no_unrepresentable_results(cls, v: int, info: ValidationInfo): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +271,7 @@ def no_unrepresentable_results(cls, v: float, info: ValidationInfo): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 9d74abd8c12..58edfab711a 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -20,6 +20,7 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext from ...version import __version__ @@ -64,7 +65,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context) -> MetadataItemOutput: + def invoke(self, context: InvocationContext) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -81,7 +82,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: if isinstance(self.items, MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -100,7 +101,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -218,7 +219,7 @@ class 
CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index f81e559e446..6a1fd6d36bc 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig from ...backend.model_management import BaseModelType, ModelType, SubModelType @@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context) -> ModelLoaderOutput: + def invoke(self, context: InvocationContext) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context) -> LoraLoaderOutput: + def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") @@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context) -> SDXLLoraLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") @@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context) -> VAEOutput: + def invoke(self, context: InvocationContext) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae @@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context) -> SeamlessModeOutput: + def invoke(self, context: InvocationContext) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context) -> UNetOutput: + def invoke(self, context: InvocationContext) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 41641152f04..78f13cc52d1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,6 +5,7 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from 
...backend.util.devices import choose_torch_device, torch_dtype @@ -112,7 +113,7 @@ def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context) -> NoiseOutput: + def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index a1e318a3802..e7b4d3d9fc5 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ def ge_one(cls, v): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context) -> ONNXModelLoaderOutput: + def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index bf59e87d270..6845637de92 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -40,6 +40,7 @@ from matplotlib.ticker import MaxNLocator from invokeai.app.invocations.primitives import FloatCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +63,7 @@ class FloatLinearRangeInvocation(BaseInvocation): description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -130,7 +131,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by 
dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index ee04345eed8..c90d3230b2b 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -17,6 +17,7 @@ UIComponent, ) from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import ( BaseInvocation, @@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context) -> BooleanOutput: + def invoke(self, context: InvocationContext) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context) -> BooleanCollectionOutput: + def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], description="The collection of integer values") - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=self.value) @@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=self.value) @@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -261,7 +262,7 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context) -> ImageOutput: + def 
invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) return ImageOutput( @@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context) -> ImageCollectionOutput: + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) @@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context) -> LatentsCollectionOutput: + def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) @@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context) -> ColorOutput: + def invoke(self, context: InvocationContext) -> ColorOutput: return ColorOutput(color=self.color) @@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context) -> ConditioningCollectionOutput: + def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4f5ef43a568..234743a0035 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -6,6 +6,7 @@ from pydantic import field_validator from invokeai.app.invocations.primitives import StringCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +92,7 @@ def promptsFromFile( break return prompts - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 
75a526cfff6..8d51674a046 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -1,4 +1,5 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (
@@ -38,7 +39,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
     )
     # TODO: precision?
-    def invoke(self, context) -> SDXLModelLoaderOutput:
+    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
@@ -127,7 +128,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     )
     # TODO: precision?
-    def invoke(self, context) -> SDXLRefinerModelLoaderOutput:
+    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py
index a4c92d9de56..182c976cd77 100644
--- a/invokeai/app/invocations/strings.py
+++ b/invokeai/app/invocations/strings.py
@@ -2,6 +2,8 @@
 import re
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -32,7 +34,7 @@ class StringSplitNegInvocation(BaseInvocation):
     string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)
-    def invoke(self, context) -> StringPosNegOutput:
+    def invoke(self, context: InvocationContext) -> StringPosNegOutput:
         p_string = ""
         n_string = ""
         brackets_depth = 0
@@ -76,7 +78,7 @@ class StringSplitInvocation(BaseInvocation):
         default="", description="Delimiter to split with.
blank will split on the first whitespace" ) - def invoke(self, context) -> String2Output: + def invoke(self, context: InvocationContext) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -94,7 +96,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -106,7 +108,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -125,7 +127,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 74a098a501c..0f4fe66ada1 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -11,6 +11,7 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_management.models.base import BaseModelType @@ -89,7 +90,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> T2IAdapterOutput: + def invoke(self, context: InvocationContext) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 0b4c472696b..19ece423761 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -13,6 +13,7 @@ ) from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -56,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. 
Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -99,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -129,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -174,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context) -> TileToPropertiesOutput: + def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -211,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context) -> PairTileImageOutput: + def invoke(self, context: InvocationContext) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -247,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. 
Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index ef174809860..71ef7ca3aa0 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -10,6 +10,7 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device @@ -42,7 +43,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) models_path = context.config.get().models_path diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index c0699eb96bb..3df230f5ee7 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -17,6 +17,7 @@ invocation_output, ) from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" @@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context) -> GraphInvocationOutput: + def invoke(self, context: InvocationContext) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context) -> IterateInvocationOutput: + def invoke(self, context: InvocationContext) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context) -> CollectInvocationOutput: + def invoke(self, context: InvocationContext) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index 559457c0e11..aab3d9c7b4b 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -8,6 +8,7 @@ ) from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField +from invokeai.app.services.shared.invocation_context import InvocationContext # Define test invocations before importing anything that uses invocations @@ -20,7 +21,7 @@ class 
ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context) -> ListPassThroughInvocationOutput: + def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -33,13 +34,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -53,7 +54,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -62,7 +63,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -75,7 +76,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -88,7 +89,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context) -> AnyTypeTestInvocationOutput: + def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -96,7 +97,7 @@ def invoke(self, context) -> AnyTypeTestInvocationOutput: class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value) From e52434cb9985ee58d6efcbe4c0b8df5df3b4af55 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:40:49 +1100 Subject: [PATCH 016/100] feat(nodes): add boards interface to invocation context --- .../app/services/shared/invocation_context.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index a849d6b17a2..cbcaa6a5489 100644 --- 
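The hunks above apply one mechanical change: every `invoke` implementation gains an explicit `context: InvocationContext` annotation, restoring the type information the earlier tidy-up dropped. For a node author, the resulting shape is roughly the following minimal sketch (the invocation name and field are illustrative, not taken from the diffs):

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("echo_string", version="1.0.0")
class EchoStringInvocation(BaseInvocation):
    value: str = InputField(default="", description="The string to echo")

    def invoke(self, context: InvocationContext) -> StringOutput:
        # With the typed context, editors can autocomplete context.images,
        # context.logger, context.latents, and the other interfaces.
        return StringOutput(value=self.value)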
a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -7,6 +7,7 @@ from torch import Tensor
 from invokeai.app.invocations.fields import MetadataField, WithMetadata
+from invokeai.app.services.boards.boards_common import BoardDTO
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO
@@ -63,6 +64,54 @@ class InvocationContextData:
     """The workflow associated with this queue item, if any."""
+class BoardsInterface:
+    def __init__(self, services: InvocationServices) -> None:
+        def create(board_name: str) -> BoardDTO:
+            """
+            Creates a board.
+
+            :param board_name: The name of the board to create.
+            """
+            return services.boards.create(board_name)
+
+        def get_dto(board_id: str) -> BoardDTO:
+            """
+            Gets a board DTO.
+
+            :param board_id: The ID of the board to get.
+            """
+            return services.boards.get_dto(board_id)
+
+        def get_all() -> list[BoardDTO]:
+            """
+            Gets all boards.
+            """
+            return services.boards.get_all()
+
+        def add_image_to_board(board_id: str, image_name: str) -> None:
+            """
+            Adds an image to a board.
+
+            :param board_id: The ID of the board to add the image to.
+            :param image_name: The name of the image to add to the board.
+            """
+            services.board_images.add_image_to_board(board_id, image_name)
+
+        def get_all_image_names_for_board(board_id: str) -> list[str]:
+            """
+            Gets all image names for a board.
+
+            :param board_id: The ID of the board to get the image names for.
+            """
+            return services.board_images.get_all_board_image_names_for_board(board_id)
+
+        self.create = create
+        self.get_dto = get_dto
+        self.get_all = get_all
+        self.add_image_to_board = add_image_to_board
+        self.get_all_image_names_for_board = get_all_image_names_for_board
+
+
 class LoggerInterface:
     def __init__(self, services: InvocationServices) -> None:
         def debug(message: str) -> None:
@@ -427,6 +476,7 @@ def __init__(
         logger: LoggerInterface,
         config: ConfigInterface,
         util: UtilInterface,
+        boards: BoardsInterface,
         data: InvocationContextData,
         services: InvocationServices,
     ) -> None:
@@ -444,6 +494,8 @@ def __init__(
         """Provides access to the app's config."""
         self.util = util
         """Provides utility methods."""
+        self.boards = boards
+        """Provides methods to interact with boards."""
         self.data = data
         """Provides data about the current queue item and invocation."""
         self.__services = services
@@ -554,6 +606,7 @@ def build_invocation_context(
     config = ConfigInterface(services=services)
     util = UtilInterface(services=services, context_data=context_data)
     conditioning = ConditioningInterface(services=services, context_data=context_data)
+    boards = BoardsInterface(services=services)
 
     ctx = InvocationContext(
         images=images,
@@ -565,6 +618,7 @@ def build_invocation_context(
         util=util,
         conditioning=conditioning,
         services=services,
+        boards=boards,
     )
 
     return ctx

From bf48e8a03a51253922201bf1b892ee90c3c332e9 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 5 Feb 2024 17:48:32 +1100
Subject: [PATCH 017/100] feat(nodes): export more things from `invocation_api`

---
 invokeai/invocation_api/__init__.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py
index e867ec3cc4e..e80bc26a003 100644
--- a/invokeai/invocation_api/__init__.py
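With the boards interface from PATCH 016 in place, a node can create boards and file its outputs without touching `services` directly. A hedged usage sketch follows; the `self.image` field, the board title, and the assumption that `BoardDTO` exposes a `board_id` attribute are illustrative, not taken from the diff:

def invoke(self, context: InvocationContext) -> ImageOutput:
    image = context.images.get_pil(self.image.image_name)
    board = context.boards.create("Generated Images")  # returns a BoardDTO
    # Saving with a board_id files the image on that board in one step
    image_dto = context.images.save(image=image, board_id=board.board_id)
    return ImageOutput(
        image=ImageField(image_name=image_dto.image_name),
        width=image_dto.width,
        height=image_dto.height,
    )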
+++ b/invokeai/invocation_api/__init__.py @@ -47,8 +47,14 @@ StringCollectionOutput, StringOutput, ) +from invokeai.app.services.boards.boards_common import BoardDTO +from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -101,9 +107,23 @@ "StringOutput", # invokeai.app.services.image_records.image_records_common "ImageCategory", + # invokeai.app.services.boards.boards_common + "BoardDTO", # invokeai.backend.stable_diffusion.diffusion.conditioning_data "BasicConditioningInfo", "ConditioningFieldData", "ExtraConditioningInfo", "SDXLConditioningInfo", + # invokeai.backend.stable_diffusion.diffusers_pipeline + "PipelineIntermediateState", + # invokeai.app.services.workflow_records.workflow_records_common + "WorkflowWithoutID", + # invokeai.app.services.config.config_default + "InvokeAIAppConfig", + # invokeai.backend.model_management.model_manager + "ModelInfo", + # invokeai.backend.model_management.models.base + "BaseModelType", + "ModelType", + "SubModelType", ] From 3f5ab02da977e3c29c9311ce24bce7e170bd9779 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:06:01 +1100 Subject: [PATCH 018/100] chore(nodes): add comments for ConfigInterface --- invokeai/app/services/shared/invocation_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cbcaa6a5489..cb989cb15e0 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -392,7 +392,7 @@ class ConfigInterface: def __init__(self, services: InvocationServices) -> None: def get() -> InvokeAIAppConfig: """ - Gets the app's config. + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """ # The config can be changed at runtime. @@ -400,6 +400,7 @@ def get() -> InvokeAIAppConfig: # We don't want nodes doing this, so we make a frozen copy. config = services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? 
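# Aside: with the widened __all__ from PATCH 017 above, community nodes can
# pull these symbols from the stable invocation_api facade instead of
# reaching into internal modules. A sketch, limited to names the patch
# actually exports:
from invokeai.invocation_api import (
    BaseModelType,
    BoardDTO,
    InvokeAIAppConfig,
    ModelInfo,
    ModelType,
    PipelineIntermediateState,
    SubModelType,
    WorkflowWithoutID,
)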
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config From 9c1e52b1ef6a0606594848f2f65457c10a4cf1e5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 6 Feb 2024 00:37:18 +1100 Subject: [PATCH 019/100] tests(nodes): fix mock InvocationContext --- tests/aa_nodes/test_graph_execution_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 3577a78ae2c..1612cbe7198 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -93,6 +93,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B logger=None, models=None, util=None, + boards=None, services=None, ) ) From afbe889d357392f7109f67fed0cc1d99c52290cc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:22:58 +1100 Subject: [PATCH 020/100] fix(nodes): restore missing context type annotations --- invokeai/app/invocations/latent.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 2cc84f80a73..5e36e73ec8f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -106,7 +106,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context) -> SchedulerOutput: + def invoke(self, context: InvocationContext) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -141,7 +141,7 @@ def prep_mask_tensor(self, mask_image): return mask_tensor @torch.no_grad() - def invoke(self, context) -> DenoiseMaskOutput: + def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) @@ -630,7 +630,7 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None @@ -777,7 +777,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.latents.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -868,7 +868,7 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) # TODO: @@ -909,7 +909,7 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context) -> LatentsOutput: 
+ def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) # TODO: @@ -998,7 +998,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): return latents @torch.no_grad() - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -1046,7 +1046,7 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents_a = context.latents.get(self.latents_a.latents_name) latents_b = context.latents.get(self.latents_b.latents_name) @@ -1147,7 +1147,7 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", ) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR From b4c774896aef541e2221dc6f185729234b63099c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:24:05 +1100 Subject: [PATCH 021/100] feat(nodes): do not hide `services` in invocation context interfaces --- .../app/services/shared/invocation_context.py | 675 ++++++++---------- 1 file changed, 317 insertions(+), 358 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cb989cb15e0..54c50bcf76b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -64,379 +64,338 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" -class BoardsInterface: - def __init__(self, services: InvocationServices) -> None: - def create(board_name: str) -> BoardDTO: - """ - Creates a board. - - :param board_name: The name of the board to create. - """ - return services.boards.create(board_name) - - def get_dto(board_id: str) -> BoardDTO: - """ - Gets a board DTO. - - :param board_id: The ID of the board to get. - """ - return services.boards.get_dto(board_id) - - def get_all() -> list[BoardDTO]: - """ - Gets all boards. - """ - return services.boards.get_all() - - def add_image_to_board(board_id: str, image_name: str) -> None: - """ - Adds an image to a board. - - :param board_id: The ID of the board to add the image to. - :param image_name: The name of the image to add to the board. - """ - services.board_images.add_image_to_board(board_id, image_name) - - def get_all_image_names_for_board(board_id: str) -> list[str]: - """ - Gets all image names for a board. - - :param board_id: The ID of the board to get the image names for. - """ - return services.board_images.get_all_board_image_names_for_board(board_id) - - self.create = create - self.get_dto = get_dto - self.get_all = get_all - self.add_image_to_board = add_image_to_board - self.get_all_image_names_for_board = get_all_image_names_for_board - - -class LoggerInterface: - def __init__(self, services: InvocationServices) -> None: - def debug(message: str) -> None: - """ - Logs a debug message. - - :param message: The message to log. 
- """ - services.logger.debug(message) - - def info(message: str) -> None: - """ - Logs an info message. - - :param message: The message to log. - """ - services.logger.info(message) - - def warning(message: str) -> None: - """ - Logs a warning message. - - :param message: The message to log. - """ - services.logger.warning(message) - - def error(message: str) -> None: - """ - Logs an error message. - - :param message: The message to log. - """ - services.logger.error(message) - - self.debug = debug - self.info = info - self.warning = warning - self.error = error - - -class ImagesInterface: +class InvocationContextInterface: def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def save( - image: Image, - board_id: Optional[str] = None, - image_category: ImageCategory = ImageCategory.GENERAL, - metadata: Optional[MetadataField] = None, - ) -> ImageDTO: - """ - Saves an image, returning its DTO. - - If the current queue item has a workflow or metadata, it is automatically saved with the image. - - :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added \ - to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the \ - invocation inherits from `WithMetadata`, that metadata will be used automatically. \ - **Use this only if you want to override or provide metadata manually!** - """ - - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata - ) - - return services.images.create( - image=image, - is_intermediate=context_data.invocation.is_intermediate, - image_category=image_category, - board_id=board_id, - metadata=metadata_, - image_origin=ResourceOrigin.INTERNAL, - workflow=context_data.workflow, - session_id=context_data.session_id, - node_id=context_data.invocation.id, - ) - - def get_pil(image_name: str) -> Image: - """ - Gets an image as a PIL Image object. - - :param image_name: The name of the image to get. - """ - return services.images.get_pil_image(image_name) - - def get_metadata(image_name: str) -> Optional[MetadataField]: - """ - Gets an image's metadata, if it has any. - - :param image_name: The name of the image to get the metadata for. - """ - return services.images.get_metadata(image_name) - - def get_dto(image_name: str) -> ImageDTO: - """ - Gets an image as an ImageDTO object. - - :param image_name: The name of the image to get. - """ - return services.images.get_dto(image_name) - - def update( - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. 
- """ - if is_intermediate is not None: - services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - services.board_images.remove_image_from_board(image_name) - else: - services.board_images.add_image_to_board(image_name, board_id) - return services.images.get_dto(image_name) - - self.save = save - self.get_pil = get_pil - self.get_metadata = get_metadata - self.get_dto = get_dto - self.update = update - - -class LatentsInterface: - def __init__( + self._services = services + self._context_data = context_data + + +class BoardsInterface(InvocationContextInterface): + def create(self, board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return self._services.boards.create(board_name) + + def get_dto(self, board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return self._services.boards.get_dto(board_id) + + def get_all(self) -> list[BoardDTO]: + """ + Gets all boards. + """ + return self._services.boards.get_all() + + def add_image_to_board(self, board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + return self._services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(self, board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return self._services.board_images.get_all_board_image_names_for_board(board_id) + + +class LoggerInterface(InvocationContextInterface): + def debug(self, message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + self._services.logger.debug(message) + + def info(self, message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + self._services.logger.info(message) + + def warning(self, message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + self._services.logger.warning(message) + + def error(self, message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. + """ + self._services.logger.error(message) + + +class ImagesInterface(InvocationContextInterface): + def save( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - def save(tensor: Tensor) -> str: - """ - Saves a latents tensor, returning its name. - - :param tensor: The latents tensor to save. - """ - - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. - # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their latents. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the latents file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. 
- - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" - services.latents.save( - name=name, - data=tensor, - ) - return name - - def get(latents_name: str) -> Tensor: - """ - Gets a latents tensor by name. - - :param latents_name: The name of the latents tensor to get. - """ - return services.latents.get(latents_name) - - self.save = save - self.get = get - - -class ConditioningInterface: - def __init__( + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow or metadata, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** + """ + + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + self._context_data.invocation.metadata + if isinstance(self._context_data.invocation, WithMetadata) + else metadata + ) + + return self._services.images.create( + image=image, + is_intermediate=self._context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=self._context_data.workflow, + session_id=self._context_data.session_id, + node_id=self._context_data.invocation.id, + ) + + def get_pil(self, image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_pil_image(image_name) + + def get_metadata(self, image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return self._services.images.get_metadata(image_name) + + def get_dto(self, image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_dto(image_name) + + def update( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. + + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - def save(conditioning_data: ConditioningFieldData) -> str: - """ - Saves a conditioning data object, returning its name. + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. - :param conditioning_data: The conditioning data to save. - """ + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. 
+        :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
+        """
+        if is_intermediate is not None:
+            self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
+        if board_id is None:
+            self._services.board_images.remove_image_from_board(image_name)
+        else:
+            self._services.board_images.add_image_to_board(image_name, board_id)
+        return self._services.images.get_dto(image_name)
+
+
+class LatentsInterface(InvocationContextInterface):
+    def save(self, tensor: Tensor) -> str:
+        """
+        Saves a latents tensor, returning its name.
+
+        :param tensor: The latents tensor to save.
+        """
+
+        # Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
+        # "mask", "noise", "masked_latents", etc.
+        #
+        # Retaining that capability in this wrapper would require either many different methods
+        # to save latents, or extra args for this method. Instead of complicating the API, we
+        # will use the same naming scheme for all latents.
+        #
+        # This has a very minor impact as we don't use them after a session completes.
+
+        # Previously, invocations chose the name for their latents. This is a bit risky, so we
+        # will generate a name for them instead. We use a uuid to ensure the name is unique.
+        #
+        # Because the name of the latents file will include the session and invocation IDs,
+        # we don't need to worry about collisions. A truncated UUIDv4 is fine.
+
+        name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
+        self._services.latents.save(
+            name=name,
+            data=tensor,
+        )
+        return name
+
+    def get(self, latents_name: str) -> Tensor:
+        """
+        Gets a latents tensor by name.
+
+        :param latents_name: The name of the latents tensor to get.
+        """
+        return self._services.latents.get(latents_name)
+
+
+class ConditioningInterface(InvocationContextInterface):
+    # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
+    # service, but it is typed to work with Tensors only. We have to fudge the types here.
+    def save(self, conditioning_data: ConditioningFieldData) -> str:
+        """
+        Saves a conditioning data object, returning its name.
-class ModelsInterface:
-    def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
-        def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
-            """
-            Checks if a model exists.
-
-            :param model_name: The name of the model to check.
-            :param base_model: The base model of the model to check.
-            :param model_type: The type of the model to check.
- """ - return services.model_manager.model_exists(model_name, base_model, model_type) - - def load( - model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: - """ - Loads a model, returning its `ModelInfo` object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - :param submodel: The submodel of the model to get. - """ - - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. - - return services.model_manager.get_model( - model_name, base_model, model_type, submodel, context_data=context_data - ) - - def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Gets a model's info, an dict-like object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - """ - return services.model_manager.model_info(model_name, base_model, model_type) - - self.exists = exists - self.load = load - self.get_info = get_info - - -class ConfigInterface: - def __init__(self, services: InvocationServices) -> None: - def get() -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ - - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. - - config = services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? - frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) - return frozen_config - - self.get = get - - -class UtilInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def sd_step_callback( - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - """ - The step callback emits a progress event with the current step, the total number of - steps, a preview image, and some other internal metadata. + :param conditioning_context_data: The conditioning data to save. + """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). - This should be called after each denoising step. + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + self._services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name - :param intermediate_state: The intermediate state of the diffusion pipeline. - :param base_model: The base model for the current denoising step. - """ + def get(self, conditioning_name: str) -> ConditioningFieldData: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. 
+ """ - # The step callback needs access to the events and the invocation queue services, but this - # represents a dangerous level of access. - # - # We wrap the step callback so that nodes do not have direct access to these services. + return self._services.latents.get(conditioning_name) # type: ignore [return-value] + + +class ModelsInterface(InvocationContextInterface): + def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. + + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return self._services.model_manager.model_exists(model_name, base_model, model_type) + + def load( + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + + return self._services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=self._context_data + ) + + def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + """ + return self._services.model_manager.model_info(model_name, base_model, model_type) + + +class ConfigInterface(InvocationContextInterface): + def get(self) -> InvokeAIAppConfig: + """ + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. + """ + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + + config = self._services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + +class UtilInterface(InvocationContextInterface): + def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each denoising step. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. 
+ """ - stable_diffusion_step_callback( - context_data=context_data, - intermediate_state=intermediate_state, - base_model=base_model, - invocation_queue=services.queue, - events=services.events, - ) + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. - self.sd_step_callback = sd_step_callback + stable_diffusion_step_callback( + context_data=self._context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=self._services.queue, + events=self._services.events, + ) deprecation_version = "3.7.0" @@ -600,14 +559,14 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - logger = LoggerInterface(services=services) + logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) latents = LatentsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) - config = ConfigInterface(services=services) + config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) - boards = BoardsInterface(services=services) + boards = BoardsInterface(services=services, context_data=context_data) ctx = InvocationContext( images=images, From 889a26c5b64fdfdadc9f834b331076084a9cf3b2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:28:29 +1100 Subject: [PATCH 022/100] feat(nodes): cache invocation interface config --- .../app/services/shared/invocation_context.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 54c50bcf76b..99e439ad96d 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -357,19 +357,24 @@ def get_info(self, model_name: str, base_model: BaseModelType, model_type: Model class ConfigInterface(InvocationContextInterface): + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + super().__init__(services, context_data) + # Config cache, only populated at runtime if requested + self._frozen_config: Optional[InvokeAIAppConfig] = None + def get(self) -> InvokeAIAppConfig: """ Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """ - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. + if self._frozen_config is None: + # The config is a live pydantic model and can be changed at runtime. + # We don't want nodes doing this, so we make a frozen copy. + self._frozen_config = self._services.configuration.get_config().model_copy( + update={"model_config": ConfigDict(frozen=True)} + ) - config = self._services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? 
-        frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
-        return frozen_config
+        return self._frozen_config

From d7b7dcc7fefe30b7c2d4cca578eecbd4ce52de1b Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 7 Feb 2024 14:36:42 +1100
Subject: [PATCH 023/100] feat(nodes): context.__services -> context._services

---
 invokeai/app/services/shared/invocation_context.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 99e439ad96d..5da85931672 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -463,7 +463,8 @@ def __init__(
         """Provides methods to interact with boards."""
         self.data = data
         """Provides data about the current queue item and invocation."""
-        self.__services = services
+        self._services = services
+        """Provides access to the full application services. This is an internal API and may change without warning."""
 
     @property
     @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`"))
@@ -475,7 +476,7 @@ def services(self) -> InvocationServices:
 
         The invocation services.
         """
-        return self.__services
+        return self._services
 
     @property
     @deprecated(

From aff46759f9a8fffdb670347dab73ad68974b7bb4 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 7 Feb 2024 14:39:26 +1100
Subject: [PATCH 024/100] feat(nodes): context.data -> context._data

---
 .../app/services/shared/invocation_context.py | 38 +++++++++----------
 tests/aa_nodes/test_graph_execution_state.py  |  2 +-
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 5da85931672..b48a6acc545 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -442,7 +442,7 @@ def __init__(
         config: ConfigInterface,
         util: UtilInterface,
         boards: BoardsInterface,
-        data: InvocationContextData,
+        context_data: InvocationContextData,
         services: InvocationServices,
     ) -> None:
         self.images = images
@@ -461,8 +461,8 @@ def __init__(
         """Provides utility methods."""
         self.boards = boards
         """Provides methods to interact with boards."""
-        self.data = data
-        """Provides data about the current queue item and invocation."""
+        self._data = context_data
+        """Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
         self._services = services
         """Provides access to the full application services. This is an internal API and may change without warning."""
 
@@ -481,77 +481,77 @@ def services(self) -> InvocationServices:
 
     @property
     @deprecated(
         version=deprecation_version,
-        reason=get_deprecation_reason("`context.graph_execution_state_api`", "`context.data.session_id`"),
+        reason=get_deprecation_reason("`context.graph_execution_state_id`", "`context._data.session_id`"),
     )
     def graph_execution_state_id(self) -> str:
         """
         **DEPRECATED as of v3.7.0**
 
-        `context.graph_execution_state_api` will be removed in v3.8.0. Use `context.data.session_id` instead. See PLACEHOLDER_URL for details.
+        `context.graph_execution_state_id` will be removed in v3.8.0. Use `context._data.session_id` instead. See PLACEHOLDER_URL for details.
The ID of the session (aka graph execution state). """ - return self.data.session_id + return self._data.session_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_id`", "`context.data.queue_id`"), + reason=get_deprecation_reason("`context.queue_id`", "`context._data.queue_id`"), ) def queue_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.queue_id` will be removed in v3.8.0. Use `context.data.queue_id` instead. See PLACEHOLDER_URL for details. + `context.queue_id` will be removed in v3.8.0. Use `context._data.queue_id` instead. See PLACEHOLDER_URL for details. The ID of the queue. """ - return self.data.queue_id + return self._data.queue_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_item_id`", "`context.data.queue_item_id`"), + reason=get_deprecation_reason("`context.queue_item_id`", "`context._data.queue_item_id`"), ) def queue_item_id(self) -> int: """ **DEPRECATED as of v3.7.0** - `context.queue_item_id` will be removed in v3.8.0. Use `context.data.queue_item_id` instead. See PLACEHOLDER_URL for details. + `context.queue_item_id` will be removed in v3.8.0. Use `context._data.queue_item_id` instead. See PLACEHOLDER_URL for details. The ID of the queue item. """ - return self.data.queue_item_id + return self._data.queue_item_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_batch_id`", "`context.data.batch_id`"), + reason=get_deprecation_reason("`context.queue_batch_id`", "`context._data.batch_id`"), ) def queue_batch_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.queue_batch_id` will be removed in v3.8.0. Use `context.data.batch_id` instead. See PLACEHOLDER_URL for details. + `context.queue_batch_id` will be removed in v3.8.0. Use `context._data.batch_id` instead. See PLACEHOLDER_URL for details. The ID of the batch. """ - return self.data.batch_id + return self._data.batch_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.workflow`", "`context.data.workflow`"), + reason=get_deprecation_reason("`context.workflow`", "`context._data.workflow`"), ) def workflow(self) -> Optional[WorkflowWithoutID]: """ **DEPRECATED as of v3.7.0** - `context.workflow` will be removed in v3.8.0. Use `context.data.workflow` instead. See PLACEHOLDER_URL for details. + `context.workflow` will be removed in v3.8.0. Use `context._data.workflow` instead. See PLACEHOLDER_URL for details. The workflow associated with this queue item, if any. 
""" - return self.data.workflow + return self._data.workflow def build_invocation_context( @@ -580,7 +580,7 @@ def build_invocation_context( config=config, latents=latents, models=models, - data=context_data, + context_data=context_data, util=util, conditioning=conditioning, services=services, diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 1612cbe7198..aba7c5694f3 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -87,7 +87,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B InvocationContext( conditioning=None, config=None, - data=None, + context_data=None, images=None, latents=None, logger=None, From b1ba18b3d1c70e17e72afc2f7d244dcc2cea3797 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 15:58:46 +1100 Subject: [PATCH 025/100] fix(nodes): do not freeze or cache config in context wrapper - The config is already cached by the config class's `get_config()` method. - The config mutates itself in its `root_path` property getter. Freezing the class makes any attempt to grab a path from the config error. Unfortunately this means we cannot easily freeze the class without fiddling with the inner workings of `InvokeAIAppConfig`, which is outside the scope here. --- .../app/services/shared/invocation_context.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b48a6acc545..cd88ec876dd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -357,24 +357,10 @@ def get_info(self, model_name: str, base_model: BaseModelType, model_type: Model class ConfigInterface(InvocationContextInterface): - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - super().__init__(services, context_data) - # Config cache, only populated at runtime if requested - self._frozen_config: Optional[InvokeAIAppConfig] = None - def get(self) -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ - - if self._frozen_config is None: - # The config is a live pydantic model and can be changed at runtime. - # We don't want nodes doing this, so we make a frozen copy. 
- self._frozen_config = self._services.configuration.get_config().model_copy( - update={"model_config": ConfigDict(frozen=True)} - ) + """Gets the app's config.""" - return self._frozen_config + return self._services.configuration.get_config() class UtilInterface(InvocationContextInterface): From e36d925bce497d2c5d2f29fe6b3e2685c4197125 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:14:35 +1100 Subject: [PATCH 026/100] fix(ui): remove original l2i node in HRF graph --- .../web/src/features/nodes/util/graph/addHrfToGraph.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 7413302fa57..8a4448833cf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -314,6 +314,10 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = ); copyConnectionsToDenoiseLatentsHrf(graph); + // The original l2i node is unnecessary now, remove it + graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE); + delete graph.nodes[LATENTS_TO_IMAGE]; + graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = { type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, From 1a191c4655e0fa59500f84dd96a23f9beca83ee4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:23:57 +1100 Subject: [PATCH 027/100] remove unused configdict import --- invokeai/app/services/shared/invocation_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cd88ec876dd..8aaa5233afd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -3,7 +3,6 @@ from deprecated import deprecated from PIL.Image import Image -from pydantic import ConfigDict from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithMetadata From c16eba78ab630b4e7f2eec74285c21509a81ca37 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:33:55 +1100 Subject: [PATCH 028/100] feat(nodes): add `WithBoard` field helper class This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it. This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions? Supporting changes: - `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board. - Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`. - Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. 
See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629 - Without `LinearUIOutputInvocation`, the `ImagesInterface.update` method is no longer needed, and removed. Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part. Note: A followup commit will implement the frontend changes to support this change. --- invokeai/app/invocations/baseinvocation.py | 33 +------- .../controlnet_image_processors.py | 12 ++- invokeai/app/invocations/cv.py | 4 +- invokeai/app/invocations/facetools.py | 4 +- invokeai/app/invocations/fields.py | 16 ++++ invokeai/app/invocations/image.py | 76 ++++++------------- invokeai/app/invocations/infill.py | 12 +-- invokeai/app/invocations/latent.py | 3 +- invokeai/app/invocations/primitives.py | 4 +- invokeai/app/invocations/tiles.py | 4 +- invokeai/app/invocations/upscale.py | 4 +- .../app/services/shared/invocation_context.py | 40 +++------- 12 files changed, 78 insertions(+), 134 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index df0596c9a15..3243714937f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -17,11 +17,8 @@ from pydantic_core import PydanticUndefined from invokeai.app.invocations.fields import ( - FieldDescriptions, FieldKind, Input, - InputFieldJSONSchemaExtra, - MetadataField, ) from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.shared.invocation_context import InvocationContext @@ -306,9 +303,7 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi "workflow", } -RESERVED_INPUT_FIELD_NAMES = { - "metadata", -} +RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"} RESERVED_OUTPUT_FIELD_NAMES = {"type"} @@ -518,29 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper - - -class WithMetadata(BaseModel): - """ - Inherit from this class if your node needs a metadata input field. - """ - - metadata: Optional[MetadataField] = Field( - default=None, - description=FieldDescriptions.metadata, - json_schema_extra=InputFieldJSONSchemaExtra( - field_kind=FieldKind.Internal, - input=Input.Connection, - orig_required=False, - ).model_dump(exclude_none=True), - ) - - -class WithWorkflow: - workflow = None - - def __init_subclass__(cls) -> None: - logger.warn( - f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
- ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index f8bdf14117c..37954c1097e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,7 +25,15 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + OutputField, + WithBoard, + WithMetadata, +) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -135,7 +143,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput: # This invocation exists for other invocations to subclass it - do not register with @invocation! -class ImageProcessorInvocation(BaseInvocation, WithMetadata): +class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): """Base class for invocations that preprocess images for ControlNet""" image: ImageField = InputField(description="The image to process") diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 1ebabf5e064..8174f19b64c 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -10,11 +10,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") -class CvInpaintInvocation(BaseInvocation, WithMetadata): +class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index a1702d6517c..fed2ed5e4f2 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -16,7 +16,7 @@ invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext @@ -619,7 +619,7 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) -class FaceIdentifierInvocation(BaseInvocation, WithMetadata): +class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. 
For use with other FaceTools.""" image: ImageField = InputField(description="Image to face detect") diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 8879f760770..c42d2f83120 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -280,6 +280,22 @@ def __init_subclass__(cls) -> None: super().__init_subclass__() +class WithBoard(BaseModel): + """ + Inherit from this class if your node needs a board input field. + """ + + board: Optional["BoardField"] = Field( + default=None, + description=FieldDescriptions.board, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Direct, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + class OutputFieldJSONSchemaExtra(BaseModel): """ Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7b74e4d96d4..f5ad5515a68 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -8,12 +8,11 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps from invokeai.app.invocations.fields import ( - BoardField, ColorField, FieldDescriptions, ImageField, - Input, InputField, + WithBoard, WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput @@ -55,7 +54,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class BlankImageInvocation(BaseInvocation, WithMetadata): +class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Creates a blank image and forwards it to the pipeline""" width: int = InputField(default=512, description="The width of the image") @@ -78,7 +77,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageCropInvocation(BaseInvocation, WithMetadata): +class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard): """Crops an image to a specified box. 
The box can be outside of the image.""" image: ImageField = InputField(description="The image to crop") @@ -149,7 +148,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImagePasteInvocation(BaseInvocation, WithMetadata): +class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): """Pastes an image into another image.""" base_image: ImageField = InputField(description="The base image") @@ -196,7 +195,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): +class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): """Extracts the alpha channel of an image as a mask.""" image: ImageField = InputField(description="The image to create the mask from") @@ -221,7 +220,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" image1: ImageField = InputField(description="The first image to multiply") @@ -248,7 +247,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelInvocation(BaseInvocation, WithMetadata): +class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard): """Gets a channel from an image.""" image: ImageField = InputField(description="The image to get the channel from") @@ -274,7 +273,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageConvertInvocation(BaseInvocation, WithMetadata): +class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard): """Converts an image to a different mode.""" image: ImageField = InputField(description="The image to convert") @@ -297,7 +296,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageBlurInvocation(BaseInvocation, WithMetadata): +class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Blurs an image""" image: ImageField = InputField(description="The image to blur") @@ -326,7 +325,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: version="1.2.1", classification=Classification.Beta, ) -class UnsharpMaskInvocation(BaseInvocation, WithMetadata): +class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an unsharp mask filter to an image""" image: ImageField = InputField(description="The image to use") @@ -394,7 +393,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageResizeInvocation(BaseInvocation, WithMetadata): +class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard): """Resizes an image to specific dimensions""" image: ImageField = InputField(description="The image to resize") @@ -424,7 +423,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageScaleInvocation(BaseInvocation, WithMetadata): +class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard): """Scales an image by a factor""" image: ImageField = InputField(description="The image to scale") @@ -459,7 +458,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageLerpInvocation(BaseInvocation, 
WithMetadata): +class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -486,7 +485,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): +class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Inverse linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -513,7 +512,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): +class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") @@ -548,7 +547,7 @@ def _get_caution_img(self) -> Image.Image: category="image", version="1.2.1", ) -class ImageWatermarkInvocation(BaseInvocation, WithMetadata): +class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard): """Add an invisible watermark to an image""" image: ImageField = InputField(description="The image to check") @@ -569,7 +568,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskEdgeInvocation(BaseInvocation, WithMetadata): +class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an edge mask to an image""" image: ImageField = InputField(description="The image to apply the mask to") @@ -608,7 +607,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskCombineInvocation(BaseInvocation, WithMetadata): +class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" mask1: ImageField = InputField(description="The first mask to combine") @@ -632,7 +631,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ColorCorrectInvocation(BaseInvocation, WithMetadata): +class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard): """ Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. 
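To make the opt-in mechanics of patch 028 concrete: a node author adds the mixins to the class bases, and `context.images.save()` resolves the board from the inherited field. A sketch with a hypothetical node (the imports mirror those in the diffs; this node is not part of the patch):

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("resave_image_example", title="Re-Save Image (Example)", tags=["image"], category="image", version="1.0.0")
class ResaveImageExampleInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Hypothetical node: re-saves an image, getting `board` and `metadata` fields for free."""

    image: ImageField = InputField(description="The image to re-save")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        # No board_id argument needed: save() reads it from the WithBoard field.
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
```

This is the same shape as the updated `SaveImageInvocation` below; the mixin replaces the hand-rolled `board: BoardField` input that nodes previously had to declare.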
@@ -736,7 +735,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): +class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard): """Adjusts the Hue of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -825,7 +824,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): +class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard): """Add or subtract a value from a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -881,7 +880,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Scale a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -926,41 +925,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput: version="1.2.1", use_cache=False, ) -class SaveImageInvocation(BaseInvocation, WithMetadata): +class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Saves an image. Unlike an image primitive, this invocation stores a copy of the image.""" image: ImageField = InputField(description=FieldDescriptions.image) - board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) - image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - - return ImageOutput.build(image_dto) - - -@invocation( - "linear_ui_output", - title="Linear UI Image Output", - tags=["primitives", "image"], - category="primitives", - version="1.0.2", - use_cache=False, -) -class LinearUIOutputInvocation(BaseInvocation, WithMetadata): - """Handles Linear UI Image Outputting tasks.""" - - image: ImageField = InputField(description=FieldDescriptions.image) - board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.images.get_dto(self.image.image_name) - - image_dto = context.images.update( - image_name=self.image.image_name, - board_id=self.board.board_id if self.board else None, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image) return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index b007edd9e42..53f6f4732fe 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -15,7 +15,7 @@ from invokeai.backend.image_util.patchmatch import PatchMatch from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -121,7 +121,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class InfillColorInvocation(BaseInvocation, WithMetadata): 
+class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with a solid color""" image: ImageField = InputField(description="The image to infill") @@ -144,7 +144,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") -class InfillTileInvocation(BaseInvocation, WithMetadata): +class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") @@ -170,7 +170,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation( "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) -class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): +class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the PatchMatch algorithm""" image: ImageField = InputField(description="The image to infill") @@ -209,7 +209,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class LaMaInfillInvocation(BaseInvocation, WithMetadata): +class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The image to infill") @@ -225,7 +225,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class CV2InfillInvocation(BaseInvocation, WithMetadata): +class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5e36e73ec8f..5449ec9af7a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -33,6 +33,7 @@ LatentsField, OutputField, UIType, + WithBoard, WithMetadata, ) from invokeai.app.invocations.ip_adapter import IPAdapterField @@ -762,7 +763,7 @@ def _lora_loader(): category="latents", version="1.2.1", ) -class LatentsToImageInvocation(BaseInvocation, WithMetadata): +class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Generates an image from latents.""" latents: LatentsField = InputField( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index c90d3230b2b..a77939943ae 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -255,9 +255,7 @@ class ImageCollectionOutput(BaseInvocationOutput): @invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") -class ImageInvocation( - BaseInvocation, -): +class ImageInvocation(BaseInvocation): """An image primitive value""" image: ImageField = InputField(description="The image to load") diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 19ece423761..cb5373bbf75 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -11,7 +11,7 @@ invocation, invocation_output, ) -from 
invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.tiles.tiles import ( @@ -232,7 +232,7 @@ def invoke(self, context: InvocationContext) -> PairTileImageOutput: version="1.1.0", classification=Classification.Beta, ) -class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): +class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Merge multiple tile images into a single image.""" # Inputs diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 71ef7ca3aa0..2e2a6ce8813 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -16,7 +16,7 @@ from invokeai.backend.util.devices import choose_torch_device from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata # TODO: Populate this from disk? # TODO: Use model manager to load? @@ -32,7 +32,7 @@ @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") -class ESRGANInvocation(BaseInvocation, WithMetadata): +class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): """Upscales an image using RealESRGAN.""" image: ImageField = InputField(description="The input image") diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 8aaa5233afd..97a62246fbc 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -5,10 +5,10 @@ from PIL.Image import Image from torch import Tensor -from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID @@ -158,7 +158,9 @@ def save( If the current queue item has a workflow or metadata, it is automatically saved with the image. :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. + :param board_id: The board ID to add the image to, if it should be added. If the invocation \ + inherits from `WithBoard`, that board will be used automatically. **Use this only if \ + you want to override or provide a board manually!** :param image_category: The category of the image. Only the GENERAL category is added \ to the gallery. :param metadata: The metadata to save with the image, if it should have any. If the \ @@ -173,11 +175,15 @@ def save( else metadata ) + # If the invocation inherits WithBoard, use that. Else, use the board_id passed in.
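The board lookup this hunk adds follows a simple precedence rule that reads more clearly extracted on its own. A sketch with hypothetical stand-in types (not the real field classes):

```python
from typing import Optional


class BoardFieldSketch:
    """Hypothetical stand-in for BoardField."""

    def __init__(self, board_id: str) -> None:
        self.board_id = board_id


class WithBoardSketch:
    """Hypothetical stand-in for the WithBoard mixin."""

    board: Optional[BoardFieldSketch] = None


def resolve_board_id(invocation: object, board_id: Optional[str]) -> Optional[str]:
    # Same shape as the save() logic below: a populated WithBoard field takes
    # precedence; the explicit argument is used only when the field is unset.
    board = invocation.board if isinstance(invocation, WithBoardSketch) else None
    return board.board_id if board is not None else board_id
```

In other words, an explicit `board_id` argument only takes effect when the invocation's `board` field is unset.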
+ board_ = self._context_data.invocation.board if isinstance(self._context_data.invocation, WithBoard) else None + board_id_ = board_.board_id if board_ is not None else board_id + return self._services.images.create( image=image, is_intermediate=self._context_data.invocation.is_intermediate, image_category=image_category, - board_id=board_id, + board_id=board_id_, metadata=metadata_, image_origin=ResourceOrigin.INTERNAL, workflow=self._context_data.workflow, @@ -209,32 +215,6 @@ def get_dto(self, image_name: str) -> ImageDTO: """ return self._services.images.get_dto(image_name) - def update( - self, - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. - """ - if is_intermediate is not None: - self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - self._services.board_images.remove_image_from_board(image_name) - else: - self._services.board_images.add_image_to_board(image_name, board_id) - return self._services.images.get_dto(image_name) - class LatentsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: From a3faa3792a91a8898674b160aea12e768bab2ac3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:34:40 +1100 Subject: [PATCH 029/100] chore(ui): regen types --- .../frontend/web/src/services/api/schema.ts | 223 ++++++++---------- 1 file changed, 96 insertions(+), 127 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index da036b6d40a..45358ed97d5 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -968,6 +968,8 @@ export type components = { * @description Creates a blank image and forwards it to the pipeline */ BlankImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -1860,6 +1862,8 @@ export type components = { * @description Infills transparent areas of an image using OpenCV Inpainting */ CV2InfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2095,6 +2099,8 @@ export type components = { * @description Canny edge detection for ControlNet */ CannyImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2482,6 +2488,8 @@ export type components = { * using a mask to only color-correct certain regions of the 
target image. */ ColorCorrectInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2590,6 +2598,8 @@ export type components = { * @description Generates a color map from the provided image */ ColorMapImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2797,6 +2807,8 @@ export type components = { * @description Applies content shuffle processing to image */ ContentShuffleImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3442,6 +3454,8 @@ export type components = { * @description Simple inpaint using opencv. */ CvInpaintInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3677,6 +3691,8 @@ export type components = { * @description Generates a depth map based on the Depth Anything algorithm */ DepthAnythingImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3910,6 +3926,8 @@ export type components = { * @description Upscales an image using RealESRGAN. */ ESRGANInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4041,6 +4059,8 @@ export type components = { * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. */ FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4873,6 +4893,8 @@ export type components = { * @description Applies HED edge detection to image */ HedImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5324,6 +5346,8 @@ export type components = { * @description Blurs an image */ ImageBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5382,6 +5406,8 @@ export type components = { * @description Gets a channel from an image. 
*/ ImageChannelInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5422,6 +5448,8 @@ export type components = { * @description Scale a specific color channel of an image. */ ImageChannelMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5473,6 +5501,8 @@ export type components = { * @description Add or subtract a value from a specific color channel of an image. */ ImageChannelOffsetInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5640,6 +5670,8 @@ export type components = { * @description Converts an image to a different mode. */ ImageConvertInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5680,6 +5712,8 @@ export type components = { * @description Crops an image to a specified box. The box can be outside of the image. */ ImageCropInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5949,6 +5983,8 @@ export type components = { * @description Adjusts the Hue of an image. */ ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5988,6 +6024,8 @@ export type components = { * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6064,6 +6102,8 @@ export type components = { * @description Linear interpolation of all pixels of an image */ ImageLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6109,6 +6149,8 @@ export type components = { * @description Multiplies two images together using `PIL.ImageChops.multiply()`. 
*/ ImageMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6144,6 +6186,8 @@ export type components = { * @description Add blur to NSFW-flagged images */ ImageNSFWBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6252,6 +6296,8 @@ export type components = { * @description Pastes an image into another image. */ ImagePasteInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6337,6 +6383,8 @@ export type components = { * @description Resizes an image to specific dimensions */ ImageResizeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6446,6 +6494,8 @@ export type components = { * @description Scales an image by a factor */ ImageScaleInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6621,6 +6671,8 @@ export type components = { * @description Add an invisible watermark to an image */ ImageWatermarkInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6676,6 +6728,8 @@ export type components = { * @description Infills transparent areas of an image with a solid color */ InfillColorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6719,6 +6773,8 @@ export type components = { * @description Infills transparent areas of an image using the PatchMatch algorithm */ InfillPatchMatchInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6765,6 +6821,8 @@ export type components = { * @description Infills transparent areas of an image with tiles of the image */ InfillTileInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7065,6 +7123,8 @@ export type components = { * @description Infills transparent areas of an image using the LaMa model */ LaMaInfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the 
image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7093,96 +7153,6 @@ export type components = { */ type: "infill_lama"; }; - /** - * Latent Consistency MonoNode - * @description Wrapper node around diffusers LatentConsistencyTxt2ImgPipeline - */ - LatentConsistencyInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description The prompt to use - */ - prompt?: string; - /** - * Num Inference Steps - * @description The number of inference steps to use, 4-8 recommended - * @default 8 - */ - num_inference_steps?: number; - /** - * Guidance Scale - * @description The guidance scale to use - * @default 8 - */ - guidance_scale?: number; - /** - * Batches - * @description The number of batches to use - * @default 1 - */ - batches?: number; - /** - * Images Per Batch - * @description The number of images per batch to use - * @default 1 - */ - images_per_batch?: number; - /** - * Seeds - * @description List of noise seeds to use - */ - seeds?: number[]; - /** - * Lcm Origin Steps - * @description The lcm origin steps to use - * @default 50 - */ - lcm_origin_steps?: number; - /** - * Width - * @description The width to use - * @default 512 - */ - width?: number; - /** - * Height - * @description The height to use - * @default 512 - */ - height?: number; - /** - * Precision - * @description floating point precision - * @default fp16 - * @enum {string} - */ - precision?: "fp16" | "fp32"; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; - /** - * type - * @default latent_consistency_mononode - * @constant - */ - type: "latent_consistency_mononode"; - }; /** * Latents Collection Primitive * @description A collection of latents tensor primitive values @@ -7310,6 +7280,8 @@ export type components = { * @description Generates an image from latents. */ LatentsToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7357,6 +7329,8 @@ export type components = { * @description Applies leres processing to image */ LeresImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7441,46 +7415,13 @@ export type components = { /** @description Type of commercial use allowed or 'No' if no commercial use is allowed. */ AllowCommercialUse?: components["schemas"]["CommercialUsage"]; }; - /** - * Linear UI Image Output - * @description Handles Linear UI Image Outputting tasks. - */ - LinearUIOutputInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default false - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** - * type - * @default linear_ui_output - * @constant - */ - type: "linear_ui_output"; - }; /** * Lineart Anime Processor * @description Applies line art anime processing to image */ LineartAnimeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7526,6 +7467,8 @@ export type components = { * @description Applies line art processing to image */ LineartImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7954,6 +7897,8 @@ export type components = { * @description Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`. */ MaskCombineInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7989,6 +7934,8 @@ export type components = { * @description Applies an edge mask to an image */ MaskEdgeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8042,6 +7989,8 @@ export type components = { * @description Extracts the alpha channel of an image as a mask. */ MaskFromAlphaInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8122,6 +8071,8 @@ export type components = { * @description Applies mediapipe face processing to image */ MediapipeFaceProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8238,6 +8189,8 @@ export type components = { * @description Merge multiple tile images into a single image. 
*/ MergeTilesToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8404,6 +8357,8 @@ export type components = { * @description Applies Midas depth processing to image */ MidasDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8449,6 +8404,8 @@ export type components = { * @description Applies MLSD processing to image */ MlsdImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8961,6 +8918,8 @@ export type components = { * @description Applies NormalBae processing to image */ NormalbaeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -9591,6 +9550,8 @@ export type components = { * @description Applies PIDI processing to image */ PidiImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10443,6 +10404,8 @@ export type components = { * @description Saves an image. Unlike an image primitive, this invocation stores a copy of the image. 
*/ SaveImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10464,8 +10427,6 @@ export type components = { use_cache?: boolean; /** @description The image to process */ image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; /** * type * @default save_image @@ -10651,6 +10612,8 @@ export type components = { * @description Applies segment anything processing to image */ SegmentAnythingProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12189,6 +12152,8 @@ export type components = { * @description Tile resampler processor */ TileResamplerProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12378,6 +12343,8 @@ export type components = { * @description Applies an unsharp mask filter to an image */ UnsharpMaskInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12846,6 +12813,8 @@ export type components = { * @description Applies Zoe depth processing to image */ ZoeDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** From 8e2b61e19f4911d372d2971c2beb61c834ef5a29 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:41:24 +1100 Subject: [PATCH 030/100] feat(ui): revise graphs to not use `LinearUIOutputInvocation` See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629 - Remove this now-unnecessary node from all graphs - Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately - Add util function to prepare the `board` field, tidy the utils - Update `socketInvocationComplete` listener to work correctly with this change I've manually tested all graph permutations that were changed (I think this is all...) 
to ensure images go to the gallery as expected: - ad-hoc upscaling - t2i w/ sd1.5 - t2i w/ sd1.5 & hrf - t2i w/ sdxl - t2i w/ sdxl + refiner - i2i w/ sd1.5 - i2i w/ sdxl - i2i w/ sdxl + refiner - canvas t2i w/ sd1.5 - canvas t2i w/ sdxl - canvas t2i w/ sdxl + refiner - canvas i2i w/ sd1.5 - canvas i2i w/ sdxl - canvas i2i w/ sdxl + refiner - canvas inpaint w/ sd1.5 - canvas inpaint w/ sdxl - canvas inpaint w/ sdxl + refiner - canvas outpaint w/ sd1.5 - canvas outpaint w/ sdxl - canvas outpaint w/ sdxl + refiner --- .../socketio/socketInvocationComplete.ts | 9 +-- .../listeners/upscaleRequested.ts | 6 +- .../nodes/util/graph/addHrfToGraph.ts | 4 +- .../nodes/util/graph/addLinearUIOutputNode.ts | 78 ------------------- .../nodes/util/graph/addNSFWCheckerToGraph.ts | 4 +- .../nodes/util/graph/addSDXLRefinerToGraph.ts | 2 +- .../nodes/util/graph/addWatermarkerToGraph.ts | 9 +-- .../util/graph/buildAdHocUpscaleGraph.ts | 40 +++------- .../graph/buildCanvasImageToImageGraph.ts | 13 ++-- .../util/graph/buildCanvasInpaintGraph.ts | 7 +- .../util/graph/buildCanvasOutpaintGraph.ts | 7 +- .../graph/buildCanvasSDXLImageToImageGraph.ts | 8 +- .../util/graph/buildCanvasSDXLInpaintGraph.ts | 8 +- .../graph/buildCanvasSDXLOutpaintGraph.ts | 8 +- .../graph/buildCanvasSDXLTextToImageGraph.ts | 11 ++- .../util/graph/buildCanvasTextToImageGraph.ts | 10 +-- .../graph/buildLinearImageToImageGraph.ts | 7 +- .../graph/buildLinearSDXLImageToImageGraph.ts | 8 +- .../graph/buildLinearSDXLTextToImageGraph.ts | 8 +- .../util/graph/buildLinearTextToImageGraph.ts | 7 +- .../features/nodes/util/graph/constants.ts | 1 - .../nodes/util/graph/getSDXLStylePrompt.ts | 11 --- .../nodes/util/graph/graphBuilderUtils.ts | 38 +++++++++ .../frontend/web/src/services/api/types.ts | 1 - 24 files changed, 108 insertions(+), 197 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index d49f35cd2ab..75fa9e10949 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { isImageOutput } from 'features/nodes/types/common'; -import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants'; +import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; @@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => { const { data } = action.payload; log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); - const { result, node, queue_batch_id, source_node_id } = data; - + 
const { result, node, queue_batch_id } = data; // This complete event has an associated image output - if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) { + if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { const { image_name } = result.image; const { canvas, gallery } = getState(); @@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => { imageDTORequest.unsubscribe(); // Add canvas images to the staging area - if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) { + if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { dispatch(addImageToStagingArea(imageDTO)); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts index 46f55ef21ff..ab989301796 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts @@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => { return; } - const { esrganModelName } = state.postprocessing; - const { autoAddBoardId } = state.gallery; - const enqueueBatchArg: BatchConfig = { prepend: true, batch: { graph: buildAdHocUpscaleGraph({ image_name, - esrganModelName, - autoAddBoardId, + state, }), runs: 1, }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 8a4448833cf..5632cfd1122 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import type { DenoiseLatentsInvocation, @@ -322,7 +323,8 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, fp32: originalLatentsToImageNode?.fp32, - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.edges.push( { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts deleted file mode 100644 index 5c78ad804ed..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts +++ /dev/null @@ -1,78 +0,0 @@ -import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import type { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; - -import { - CANVAS_OUTPUT, - LATENTS_TO_IMAGE, - LATENTS_TO_IMAGE_HRF_HR, - LINEAR_UI_OUTPUT, - NSFW_CHECKER, - WATERMARKER, -} from './constants'; - -/** - * Set the `use_cache` field on the linear/canvas graph's final image output node to False. 
- */ -export const addLinearUIOutputNode = (state: RootState, graph: NonNullableGraph): void => { - const activeTabName = activeTabNameSelector(state); - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const { autoAddBoardId } = state.gallery; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - is_intermediate, - use_cache: false, - board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId }, - }; - - graph.nodes[LINEAR_UI_OUTPUT] = linearUIOutputNode; - - const destination = { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }; - - if (WATERMARKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: WATERMARKER, - field: 'image', - }, - destination, - }); - } else if (NSFW_CHECKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination, - }); - } else if (CANVAS_OUTPUT in graph.nodes) { - graph.edges.push({ - source: { - node_id: CANVAS_OUTPUT, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE_HRF_HR, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination, - }); - } -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts index 4a8e77abfa0..35fc3246890 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts @@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store'; import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addNSFWCheckerToGraph = ( state: RootState, @@ -21,7 +22,8 @@ export const addNSFWCheckerToGraph = ( const nsfwCheckerNode: ImageNSFWBlurInvocation = { id: NSFW_CHECKER, type: 'img_nsfw', - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts index 708353e4d6a..fc4d998969d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts @@ -24,7 +24,7 @@ import { SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getSDXLStylePrompts } from './graphBuilderUtils'; import { upsertMetadata } from './metadata'; export const addSDXLRefinerToGraph = ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts index 99c5c07be47..61beb11df49 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts @@ 
-1,5 +1,4 @@ import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import type { ImageNSFWBlurInvocation, ImageWatermarkInvocation, @@ -8,16 +7,13 @@ import type { } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addWatermarkerToGraph = ( state: RootState, graph: NonNullableGraph, nodeIdToAddTo = LATENTS_TO_IMAGE ): void => { - const activeTabName = activeTabNameSelector(state); - - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined; const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined; @@ -30,7 +26,8 @@ export const addWatermarkerToGraph = ( const watermarkerNode: ImageWatermarkInvocation = { id: WATERMARKER, type: 'img_watermark', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[WATERMARKER] = watermarkerNode; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts index fa20206d91f..52c09b1db06 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts @@ -1,51 +1,33 @@ -import type { BoardId } from 'features/gallery/store/types'; -import type { ParamESRGANModelName } from 'features/parameters/store/postprocessingSlice'; -import type { ESRGANInvocation, Graph, LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; +import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; +import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types'; -import { ESRGAN, LINEAR_UI_OUTPUT } from './constants'; +import { ESRGAN } from './constants'; import { addCoreMetadataNode, upsertMetadata } from './metadata'; type Arg = { image_name: string; - esrganModelName: ParamESRGANModelName; - autoAddBoardId: BoardId; + state: RootState; }; -export const buildAdHocUpscaleGraph = ({ image_name, esrganModelName, autoAddBoardId }: Arg): Graph => { +export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => { + const { esrganModelName } = state.postprocessing; + const realesrganNode: ESRGANInvocation = { id: ESRGAN, type: 'esrgan', image: { image_name }, model_name: esrganModelName, - is_intermediate: true, - }; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - use_cache: false, - is_intermediate: false, - board: autoAddBoardId === 'none' ? 
undefined : { board_id: autoAddBoardId }, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; const graph: NonNullableGraph = { id: `adhoc-esrgan-graph`, nodes: { [ESRGAN]: realesrganNode, - [LINEAR_UI_OUTPUT]: linearUIOutputNode, }, - edges: [ - { - source: { - node_id: ESRGAN, - field: 'image', - }, - destination: { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }, - }, - ], + edges: [], }; addCoreMetadataNode(graph, {}, ESRGAN); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts index 3002e05441b..bc6a83f4fa6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -132,7 +132,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima [CANVAS_OUTPUT]: { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -242,7 +243,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -284,7 +286,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -355,7 +358,5 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index bb52a44a8e4..d983b9cf4f5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -12,7 +13,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from 
'./addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -191,7 +191,8 @@ export const buildCanvasInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -663,7 +664,5 @@ export const buildCanvasInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index b82b55cfee7..1d028943818 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, @@ -11,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -200,7 +200,8 @@ export const buildCanvasOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -769,7 +770,5 @@ export const buildCanvasOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts index 1b586371a02..58269afce3f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'servi import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -26,7 +25,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -246,7 +245,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + 
is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -368,7 +368,5 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index 00fea9a37e6..5902dee2fc4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -12,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -44,7 +43,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Inpaint graph. @@ -190,7 +189,8 @@ export const buildCanvasSDXLInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -687,7 +687,5 @@ export const buildCanvasSDXLInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index f85760d8f2e..7a78750e8d2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -11,7 +11,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -46,7 +45,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Outpaint graph. 
@@ -199,7 +198,8 @@ export const buildCanvasSDXLOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -786,7 +786,5 @@ export const buildCanvasSDXLOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts index 91d9da4cb5d..22da39c67da 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -24,7 +23,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -222,7 +221,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -254,7 +254,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -330,7 +331,5 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts index 967dd3ff4a5..93f0470c7ad 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -211,7 +211,8 @@ export const buildCanvasTextToImageGraph = 
(state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -243,7 +244,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -310,7 +312,5 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts index c76776d94d3..d1f1546b23b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -117,7 +117,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [DENOISE_LATENTS]: { type: 'denoise_latents', @@ -358,7 +359,5 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts index 9ae602bcacb..de4ad7ceceb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -25,7 +24,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from 
'./metadata'; /** @@ -120,7 +119,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [SDXL_DENOISE_LATENTS]: { type: 'denoise_latents', @@ -380,7 +380,5 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts index 222dc1a3595..58b97b07c75 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -23,7 +22,7 @@ import { SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => { @@ -120,7 +119,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -281,7 +281,5 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index 0a45d91debc..b2b84cfdad7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -1,11 +1,11 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -119,7 +119,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: 
getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -267,7 +268,5 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts index 363d3191210..767bf25df0a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts @@ -9,7 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr'; export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf'; export const RESIZE_HRF = 'resize_hrf'; export const ESRGAN_HRF = 'esrgan_hrf'; -export const LINEAR_UI_OUTPUT = 'linear_ui_output'; export const NSFW_CHECKER = 'nsfw_checker'; export const WATERMARKER = 'invisible_watermark'; export const NOISE = 'noise'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts b/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts deleted file mode 100644 index e1cd8518fdd..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type { RootState } from 'app/store/store'; - -export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { - const { positivePrompt, negativePrompt } = state.generation; - const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; - - return { - positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, - negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, - }; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts new file mode 100644 index 00000000000..cb6fc9acf1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -0,0 +1,38 @@ +import type { RootState } from 'app/store/store'; +import type { BoardField } from 'features/nodes/types/common'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; + +/** + * Gets the board field, based on the autoAddBoardId setting. + */ +export const getBoardField = (state: RootState): BoardField | undefined => { + const { autoAddBoardId } = state.gallery; + if (autoAddBoardId === 'none') { + return undefined; + } + return { board_id: autoAddBoardId }; +}; + +/** + * Gets the SDXL style prompts, based on the concat setting. + */ +export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { + const { positivePrompt, negativePrompt } = state.generation; + const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; + + return { + positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, + negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, + }; +}; + +/** + * Gets the is_intermediate field, based on the active tab and shouldAutoSave setting. 
+ */ +export const getIsIntermediate = (state: RootState) => { + const activeTabName = activeTabNameSelector(state); + if (activeTabName === 'unifiedCanvas') { + return !state.canvas.shouldAutoSave; + } + return false; +}; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 1382fbe275a..55ff808b404 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -132,7 +132,6 @@ export type DivideInvocation = s['DivideInvocation']; export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; export type SeamlessModeInvocation = s['SeamlessModeInvocation']; -export type LinearUIOutputInvocation = s['LinearUIOutputInvocation']; export type MetadataInvocation = s['MetadataInvocation']; export type CoreMetadataInvocation = s['CoreMetadataInvocation']; export type MetadataItemInvocation = s['MetadataItemInvocation']; From a5db204629181af352b09ae5af86618ece58b408 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:01:39 +1100 Subject: [PATCH 031/100] tidy(nodes): remove unnecessary, shadowing class attr declarations --- invokeai/app/services/invocation_services.py | 27 -------------------- 1 file changed, 27 deletions(-) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 11a4de99d6e..51bfd5d77a1 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -36,33 +36,6 @@ class InvocationServices: """Services that can be used by invocations""" - # TODO: Just forward-declared everything due to circular dependencies. Fix structure. 
- board_images: "BoardImagesServiceABC" - board_image_record_storage: "BoardImageRecordStorageBase" - boards: "BoardServiceABC" - board_records: "BoardRecordStorageBase" - configuration: "InvokeAIAppConfig" - events: "EventServiceBase" - graph_execution_manager: "ItemStorageABC[GraphExecutionState]" - images: "ImageServiceABC" - image_records: "ImageRecordStorageBase" - image_files: "ImageFileStorageBase" - latents: "LatentsStorageBase" - logger: "Logger" - model_manager: "ModelManagerServiceBase" - model_records: "ModelRecordServiceBase" - download_queue: "DownloadQueueServiceBase" - model_install: "ModelInstallServiceBase" - processor: "InvocationProcessorABC" - performance_statistics: "InvocationStatsServiceBase" - queue: "InvocationQueueABC" - session_queue: "SessionQueueBase" - session_processor: "SessionProcessorBase" - invocation_cache: "InvocationCacheBase" - names: "NameServiceBase" - urls: "UrlServiceBase" - workflow_records: "WorkflowRecordsStorageBase" - def __init__( self, board_images: "BoardImagesServiceABC", From db6bc7305a4830799a62b21cd3ff46f30bc4f35e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:10:25 +1100 Subject: [PATCH 032/100] fix(nodes): rearrange fields.py to avoid needing forward refs --- invokeai/app/invocations/fields.py | 92 +++++++++++++++--------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index c42d2f83120..40d403c03d9 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -182,6 +182,51 @@ class FieldDescriptions: freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + # endregion + + class MetadataField(RootModel): """ Pydantic model for metadata with custom root of type dict[str, Any]. @@ -285,7 +330,7 @@ class WithBoard(BaseModel): Inherit from this class if your node needs a board input field. 
""" - board: Optional["BoardField"] = Field( + board: Optional[BoardField] = Field( default=None, description=FieldDescriptions.board, json_schema_extra=InputFieldJSONSchemaExtra( @@ -518,48 +563,3 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) - - -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - # endregion From 2932652787221019eb1b18e5e6df395b2b5491bb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:11:22 +1100 Subject: [PATCH 033/100] tidy(nodes): delete onnx.py It doesn't work and keeping it updated to prevent the app from starting was getting tedious. Deleted. 
--- invokeai/app/invocations/onnx.py | 510 ------------------------------- 1 file changed, 510 deletions(-) delete mode 100644 invokeai/app/invocations/onnx.py diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py deleted file mode 100644 index e7b4d3d9fc5..00000000000 --- a/invokeai/app/invocations/onnx.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) - -import inspect - -# from contextlib import ExitStack -from typing import List, Literal, Union - -import numpy as np -import torch -from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator -from tqdm import tqdm - -from invokeai.app.invocations.fields import ( - FieldDescriptions, - Input, - InputField, - OutputField, - UIComponent, - UIType, - WithMetadata, -) -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import BaseModelType, ModelType, SubModelType - -from ...backend.model_management import ONNXModelPatcher -from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util import choose_torch_device -from ..util.ti_utils import extract_ti_triggers_from_prompt -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - invocation, - invocation_output, -) -from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler -from .model import ClipField, ModelInfo, UNetField, VaeField - -ORT_TO_NP_TYPE = { - "tensor(bool)": np.bool_, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - "tensor(int16)": np.int16, - "tensor(uint16)": np.uint16, - "tensor(int32)": np.int32, - "tensor(uint32)": np.uint32, - "tensor(int64)": np.int64, - "tensor(uint64)": np.uint64, - "tensor(float16)": np.float16, - "tensor(float)": np.float32, - "tensor(double)": np.float64, -} - -PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())] - - -@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") -class ONNXPromptInvocation(BaseInvocation): - prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) - clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - ) - with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.clip.loras - ] - - ti_list = [] - for trigger in extract_ti_triggers_from_prompt(self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except Exception: - # print(e) - # import traceback - # 
print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') - if loras or ti_list: - text_encoder.release_session() - with ( - ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), - ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager), - ): - text_encoder.create_session() - - # copy from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153 - text_inputs = tokenizer( - self.prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - """ - untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - """ - - prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (prompt_embeds, None)) - - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) - - -# Text to image -@invocation( - "t2l_onnx", - title="ONNX Text to Latents", - tags=["latents", "inference", "txt2img", "onnx"], - category="latents", - version="1.0.0", -) -class ONNXTextToLatentsInvocation(BaseInvocation): - """Generates latents from conditionings.""" - - positive_conditioning: ConditioningField = InputField( - description=FieldDescriptions.positive_cond, - input=Input.Connection, - ) - negative_conditioning: ConditioningField = InputField( - description=FieldDescriptions.negative_cond, - input=Input.Connection, - ) - noise: LatentsField = InputField( - description=FieldDescriptions.noise, - input=Input.Connection, - ) - steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) - cfg_scale: Union[float, List[float]] = InputField( - default=7.5, - ge=1, - description=FieldDescriptions.cfg_scale, - ) - scheduler: SAMPLER_NAME_VALUES = InputField( - default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler - ) - precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) - unet: UNetField = InputField( - description=FieldDescriptions.unet, - input=Input.Connection, - ) - control: Union[ControlField, list[ControlField]] = InputField( - default=None, - description=FieldDescriptions.control, - ) - # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") - - @field_validator("cfg_scale") - def ge_one(cls, v): - """validate that all cfg_scale values are >= 1""" - if isinstance(v, list): - for i in v: - if i < 1: - raise ValueError("cfg_scale must be greater than 1") - else: - if v < 1: - raise ValueError("cfg_scale must be greater than 1") - return v - - # based on - # 
https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: - c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - if isinstance(c, torch.Tensor): - c = c.cpu().numpy() - if isinstance(uc, torch.Tensor): - uc = uc.cpu().numpy() - device = torch.device(choose_torch_device()) - prompt_embeds = np.concatenate([uc, c]) - - latents = context.services.latents.get(self.noise.latents_name) - if isinstance(latents, torch.Tensor): - latents = latents.cpu().numpy() - - # TODO: better execution device handling - latents = latents.astype(ORT_TO_NP_TYPE[self.precision]) - - # get the initial random noise unless the user supplied it - do_classifier_free_guidance = True - # latents_dtype = prompt_embeds.dtype - # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - # if latents.shape != latents_shape: - # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, - seed=0, # TODO: refactor this node - ) - - def torch2numpy(latent: torch.Tensor): - return latent.cpu().numpy() - - def numpy2torch(latent, device): - return torch.from_numpy(latent).to(device) - - def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - ) - - scheduler.set_timesteps(self.steps) - latents = latents * np.float64(scheduler.init_noise_sigma) - - extra_step_kwargs = {} - if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): - extra_step_kwargs.update( - eta=0.0, - ) - - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) - - with unet_info as unet: # , ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.unet.loras - ] - - if loras: - unet.release_session() - with ONNXModelPatcher.apply_lora_unet(unet, loras): - # TODO: - _, _, h, w = latents.shape - unet.create_session(h, w) - - timestep_dtype = next( - (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - for i in tqdm(range(len(scheduler.timesteps))): - t = scheduler.timesteps[i] - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = 
unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = scheduler.step( - numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs - ) - latents = torch2numpy(scheduler_output.prev_sample) - - state = PipelineIntermediateState( - run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample - ) - dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state) - - # call the callback, if provided - # if callback is not None and i % callback_steps == 0: - # callback(i, t, latents) - - torch.cuda.empty_cache() - - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, latents) - # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) - - -# Latent to image -@invocation( - "l2i_onnx", - title="ONNX Latents to Image", - tags=["latents", "image", "vae", "onnx"], - category="image", - version="1.2.0", -) -class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): - """Generates an image from latents.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.denoised_latents, - input=Input.Connection, - ) - vae: VaeField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) - - if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") - - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - ) - - # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - - with vae_info as vae: - vae.create_session() - - # copied from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427 - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - image = VaeImageProcessor.numpy_to_pil(image)[0] - - torch.cuda.empty_cache() - - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -@invocation_output("model_loader_output_onnx") -class ONNXModelLoaderOutput(BaseInvocationOutput): - """Model loader output""" - - unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") - clip: 
ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") - vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") - vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") - - -class OnnxModelField(BaseModel): - """Onnx model field""" - - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) - - -@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") -class OnnxModelLoaderInvocation(BaseInvocation): - """Loads a main model, outputting its submodels.""" - - model: OnnxModelField = InputField( - description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel - ) - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! 
Check if model corrupted" - ) - """ - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeEncoder, - ), - ), - ) From de0b72528cbb3f7ffcdafcad5c3232e81e748fd2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:41:23 +1100 Subject: [PATCH 034/100] feat(nodes): replace latents service with tensors and conditioning services - New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods --- invokeai/app/api/dependencies.py | 18 +++-- invokeai/app/invocations/latent.py | 36 +++++----- invokeai/app/invocations/noise.py | 2 +- invokeai/app/invocations/primitives.py | 6 +- .../invocation_cache_memory.py | 3 +- invokeai/app/services/invocation_services.py | 12 +++- .../app/services/latents_storage/__init__.py | 0 .../latents_storage/latents_storage_disk.py | 58 ---------------- .../latents_storage_forward_cache.py | 68 ------------------- .../pickle_storage_base.py} | 18 ++--- .../pickle_storage_forward_cache.py | 58 ++++++++++++++++ .../pickle_storage/pickle_storage_torch.py | 62 +++++++++++++++++ .../app/services/shared/invocation_context.py | 49 ++++++------- 13 files changed, 197 insertions(+), 193 deletions(-) delete mode 100644 invokeai/app/services/latents_storage/__init__.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_disk.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_forward_cache.py rename invokeai/app/services/{latents_storage/latents_storage_base.py => pickle_storage/pickle_storage_base.py} (68%) create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c8309e1729e..6bb0915cb6e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,9 +2,14 @@ from logging import Logger +import torch + from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from 
invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache +from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -23,8 +28,6 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker -from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage -from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL @@ -68,6 +71,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger logger.debug(f"Internet connectivity is {config.internet_available}") output_folder = config.output_path + if output_folder is None: + raise ValueError("Output folder is not set") + image_files = DiskImageFileStorage(f"{output_folder}/images") db = init_db(config=config, logger=logger, image_files=image_files) @@ -84,7 +90,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) + tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) + conditioning = PickleStorageForwardCache( + PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) download_queue_service = DownloadQueueService(event_bus=events) @@ -117,7 +126,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records=image_records, images=images, invocation_cache=invocation_cache, - latents=latents, logger=logger, model_manager=model_manager, model_records=model_record_service, @@ -131,6 +139,8 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger session_queue=session_queue, urls=urls, workflow_records=workflow_records, + tensors=tensors, + conditioning=conditioning, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5449ec9af7a..94440d3e2aa 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -163,11 +163,11 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = context.latents.save(tensor=masked_latents) + masked_latents_name = context.tensors.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = context.latents.save(tensor=mask) + mask_name = context.tensors.save(tensor=mask) return DenoiseMaskOutput.build( 
mask_name=mask_name, @@ -621,10 +621,10 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None - mask = context.latents.get(self.denoise_mask.mask_name) + mask = context.tensors.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.latents.get(self.noise.latents_name) + noise = context.tensors.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -752,7 +752,7 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=result_latents) + name = context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -888,7 +888,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -930,7 +930,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1011,7 +1011,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) latents = latents.to("cpu") - name = context.latents.save(tensor=latents) + name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @@ -1048,8 +1048,8 @@ class 
BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.latents.get(self.latents_a.latents_name) - latents_b = context.latents.get(self.latents_b.latents_name) + latents_a = context.tensors.get(self.latents_a.latents_name) + latents_b = context.tensors.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1103,7 +1103,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=blended_latents) + name = context.tensors.save(tensor=blended_latents) return LatentsOutput.build(latents_name=name, latents=blended_latents) @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1158,7 +1158,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = context.latents.save(tensor=cropped_latents) + name = context.tensors.save(tensor=cropped_latents) return LatentsOutput.build(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 78f13cc52d1..74b3d6e4cb1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -121,5 +121,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = context.latents.save(tensor=noise) + name = context.tensors.save(tensor=noise) return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index a77939943ae..082d5432ccf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -313,9 +313,7 @@ def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "De class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" - latents: LatentsField = OutputField( - description=FieldDescriptions.latents, - ) + latents: LatentsField = OutputField(description=FieldDescriptions.latents) width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) @@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 4a503b3c6b1..c700f81186f 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -37,7 +37,8 @@ def start(self, invoker: Invoker) -> None: if self._max_cache_size == 0: return self._invoker.services.images.on_deleted(self._delete_by_match) 
- self._invoker.services.latents.on_deleted(self._delete_by_match) + self._invoker.services.tensors.on_deleted(self._delete_by_match) + self._invoker.services.conditioning.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: with self._lock: diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 51bfd5d77a1..81885781acb 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -6,6 +6,10 @@ if TYPE_CHECKING: from logging import Logger + import torch + + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData + from .board_image_records.board_image_records_base import BoardImageRecordStorageBase from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase @@ -21,11 +25,11 @@ from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC - from .latents_storage.latents_storage_base import LatentsStorageBase from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase + from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -48,7 +52,6 @@ def __init__( images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", - latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", model_records: "ModelRecordServiceBase", @@ -63,6 +66,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", + tensors: "PickleStorageBase[torch.Tensor]", + conditioning: "PickleStorageBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records @@ -74,7 +79,6 @@ def __init__( self.images = images self.image_files = image_files self.image_records = image_records - self.latents = latents self.logger = logger self.model_manager = model_manager self.model_records = model_records @@ -89,3 +93,5 @@ def __init__( self.names = names self.urls = urls self.workflow_records = workflow_records + self.tensors = tensors + self.conditioning = conditioning diff --git a/invokeai/app/services/latents_storage/__init__.py b/invokeai/app/services/latents_storage/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py deleted file mode 100644 index 9192b9147f7..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import Union - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class DiskLatentsStorage(LatentsStorageBase): - """Stores latents in a folder on disk without caching""" - - __output_folder: Path - - 
def __init__(self, output_folder: Union[str, Path]): - self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__output_folder.mkdir(parents=True, exist_ok=True) - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_latents() - - def get(self, name: str) -> torch.Tensor: - latent_path = self.get_path(name) - return torch.load(latent_path) - - def save(self, name: str, data: torch.Tensor) -> None: - self.__output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self.get_path(name) - torch.save(data, latent_path) - - def delete(self, name: str) -> None: - latent_path = self.get_path(name) - latent_path.unlink() - - def get_path(self, name: str) -> Path: - return self.__output_folder / name - - def _delete_all_latents(self) -> None: - """ - Deletes all latents from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - deleted_latents_count = 0 - freed_space = 0 - for latents_file in Path(self.__output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py deleted file mode 100644 index 6232b76a27d..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from queue import Queue -from typing import Dict, Optional - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class ForwardCacheLatentsStorage(LatentsStorageBase): - """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" - - __cache: Dict[str, torch.Tensor] - __cache_ids: Queue - __max_cache_size: int - __underlying_storage: LatentsStorageBase - - def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): - super().__init__() - self.__underlying_storage = underlying_storage - self.__cache = {} - self.__cache_ids = Queue() - self.__max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self.__underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self.__underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> torch.Tensor: - cache_item = self.__get_cache(name) - if cache_item is not None: - return cache_item - - latent = self.__underlying_storage.get(name) - self.__set_cache(name, latent) - return latent - - def save(self, name: str, data: torch.Tensor) -> None: - self.__underlying_storage.save(name, data) - self.__set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self.__underlying_storage.delete(name) - if name in self.__cache: - del self.__cache[name] - self._on_deleted(name) - - def __get_cache(self, name: str) -> Optional[torch.Tensor]: - return None if name not in self.__cache else self.__cache[name] 
- - def __set_cache(self, name: str, data: torch.Tensor): - if name not in self.__cache: - self.__cache[name] = data - self.__cache_ids.put(name) - if self.__cache_ids.qsize() > self.__max_cache_size: - self.__cache.pop(self.__cache_ids.get()) diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py similarity index 68% rename from invokeai/app/services/latents_storage/latents_storage_base.py rename to invokeai/app/services/pickle_storage/pickle_storage_base.py index 9fa42b0ae61..558b97c0f1b 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_base.py @@ -1,15 +1,15 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Generic, TypeVar -import torch +T = TypeVar("T") -class LatentsStorageBase(ABC): - """Responsible for storing and retrieving latents.""" +class PickleStorageBase(ABC, Generic[T]): + """Responsible for storing and retrieving non-serializable data using a pickler.""" - _on_changed_callbacks: list[Callable[[torch.Tensor], None]] + _on_changed_callbacks: list[Callable[[T], None]] _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: @@ -17,18 +17,18 @@ def __init__(self) -> None: self._on_deleted_callbacks = [] @abstractmethod - def get(self, name: str) -> torch.Tensor: + def get(self, name: str) -> T: pass @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: T) -> None: pass @abstractmethod def delete(self, name: str) -> None: pass - def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None: + def on_changed(self, on_changed: Callable[[T], None]) -> None: """Register a callback for when an item is changed""" self._on_changed_callbacks.append(on_changed) @@ -36,7 +36,7 @@ def on_deleted(self, on_deleted: Callable[[str], None]) -> None: """Register a callback for when an item is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_changed(self, item: torch.Tensor) -> None: + def _on_changed(self, item: T) -> None: for callback in self._on_changed_callbacks: callback(item) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py new file mode 100644 index 00000000000..3002d9e045d --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py @@ -0,0 +1,58 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageForwardCache(PickleStorageBase[T]): + def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, name: str) -> T: + 
cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(name) + self._set_cache(name, latent) + return latent + + def save(self, name: str, data: T) -> None: + self._underlying_storage.save(name, data) + self._set_cache(name, data) + self._on_changed(data) + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py new file mode 100644 index 00000000000..0b3c9af7a33 --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from pathlib import Path +from typing import TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageTorch(PickleStorageBase[T]): + """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" + + def __init__(self, output_folder: Path, item_type_name: "str"): + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self._item_type_name = item_type_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, name: str) -> T: + latent_path = self._get_path(name) + return torch.load(latent_path) + + def save(self, name: str, data: T) -> None: + self._output_folder.mkdir(parents=True, exist_ok=True) + latent_path = self._get_path(name) + torch.save(data, latent_path) + + def delete(self, name: str) -> None: + latent_path = self._get_path(name) + latent_path.unlink() + + def _get_path(self, name: str) -> Path: + return self._output_folder / name + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. 
Must call `start()` first.") + + deleted_latents_count = 0 + freed_space = 0 + for latents_file in Path(self._output_folder).glob("*"): + if latents_file.is_file(): + freed_space += latents_file.stat().st_size + deleted_latents_count += 1 + latents_file.unlink() + if deleted_latents_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 97a62246fbc..6756b1f5c6c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -216,48 +216,46 @@ def get_dto(self, image_name: str) -> ImageDTO: return self._services.images.get_dto(image_name) -class LatentsInterface(InvocationContextInterface): +class TensorsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: """ - Saves a latents tensor, returning its name. + Saves a tensor, returning its name. - :param tensor: The latents tensor to save. + :param tensor: The tensor to save. """ # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. # "mask", "noise", "masked_latents", etc. # # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. + # to save tensors, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all tensors. # # This has a very minor impact as we don't use them after a session completes. - # Previously, invocations chose the name for their latents. This is a bit risky, so we + # Previously, invocations chose the name for their tensors. This is a bit risky, so we # will generate a name for them instead. We use a uuid to ensure the name is unique. # - # Because the name of the latents file will includes the session and invocation IDs, + # Because the name of the tensors file will includes the session and invocation IDs, # we don't need to worry about collisions. A truncated UUIDv4 is fine. name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.latents.save( + self._services.tensors.save( name=name, data=tensor, ) return name - def get(self, latents_name: str) -> Tensor: + def get(self, tensor_name: str) -> Tensor: """ - Gets a latents tensor by name. + Gets a tensor by name. - :param latents_name: The name of the latents tensor to get. + :param tensor_name: The name of the tensor to get. """ - return self._services.latents.get(latents_name) + return self._services.tensors.get(tensor_name) class ConditioningInterface(InvocationContextInterface): - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. def save(self, conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. @@ -265,15 +263,12 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. 
- # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). + # See comment in TensorsInterface.save for why we generate the name here. - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - self._services.latents.save( + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.conditioning.save( name=name, - data=conditioning_data, # type: ignore [arg-type] + data=conditioning_data, ) return name @@ -284,7 +279,7 @@ def get(self, conditioning_name: str) -> ConditioningFieldData: :param conditioning_name: The name of the conditioning data to get. """ - return self._services.latents.get(conditioning_name) # type: ignore [return-value] + return self._services.conditioning.get(conditioning_name) class ModelsInterface(InvocationContextInterface): @@ -400,7 +395,7 @@ class InvocationContext: def __init__( self, images: ImagesInterface, - latents: LatentsInterface, + tensors: TensorsInterface, conditioning: ConditioningInterface, models: ModelsInterface, logger: LoggerInterface, @@ -412,8 +407,8 @@ def __init__( ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" - self.latents = latents - """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" + self.tensors = tensors + """Provides methods to save and get tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning """Provides methods to save and get conditioning data.""" self.models = models @@ -532,7 +527,7 @@ def build_invocation_context( logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) - latents = LatentsInterface(services=services, context_data=context_data) + tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) @@ -543,7 +538,7 @@ def build_invocation_context( images=images, logger=logger, config=config, - latents=latents, + tensors=tensors, models=models, context_data=context_data, util=util, From a7f91b3e01e2f9970cbf352ea3f9b8f85d76247b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:17:23 +1100 Subject: [PATCH 035/100] tidy(nodes): do not refer to files as latents in `PickleStorageTorch` --- .../pickle_storage/pickle_storage_torch.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index 0b3c9af7a33..7b18dc0625e 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -48,15 +48,15 @@ def _delete_all_items(self) -> None: if not self._invoker: raise ValueError("Invoker is not set. 
Must call `start()` first.") - deleted_latents_count = 0 + deleted_count = 0 freed_space = 0 - for latents_file in Path(self._output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: freed_space_in_mb = round(freed_space / 1024 / 1024, 2) self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" ) From 0266946d3d86fd4d85d7088a93faa7945b1c6144 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:35:58 +1100 Subject: [PATCH 036/100] fix(nodes): add super init to `PickleStorageTorch` --- invokeai/app/services/pickle_storage/pickle_storage_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index 7b18dc0625e..de411bbf47d 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -15,6 +15,7 @@ class PickleStorageTorch(PickleStorageBase[T]): """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" def __init__(self, output_folder: Path, item_type_name: "str"): + super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._item_type_name = item_type_name From fd30cb4d9057c5e7244df0c8e157128e6a2188f6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:43:33 +1100 Subject: [PATCH 037/100] feat(nodes): ItemStorageABC typevar no longer bound to pydantic.BaseModel This bound is totally unnecessary. There's no requirement for any implementation of `ItemStorageABC` to work only on pydantic models. 
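For readers unfamiliar with TypeVar bounds, a minimal toy illustration of what the bound was ruling out (`Storage` below is illustrative, not a class from this repo):

    from typing import Generic, TypeVar

    from pydantic import BaseModel

    # The old declaration: only BaseModel subclasses may parameterize the generic.
    BoundT = TypeVar("BoundT", bound=BaseModel)

    # The new declaration: any type is a valid parameter, including torch.Tensor.
    T = TypeVar("T")

    class Storage(Generic[T]):
        def set(self, item: T) -> str: ...
        def get(self, item_id: str) -> T: ...

    # Storage[bytes] type-checks with the unbound T; under BoundT a type
    # checker would reject it, because bytes is not a BaseModel subclass.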
--- invokeai/app/services/item_storage/item_storage_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index c93edf5188d..d7366791594 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -1,9 +1,7 @@ from abc import ABC, abstractmethod from typing import Callable, Generic, TypeVar -from pydantic import BaseModel - -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T") class ItemStorageABC(ABC, Generic[T]): From 25386a76ef52e7e6eafa524408cf6bb33e1ad356 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:44:30 +1100 Subject: [PATCH 038/100] tidy(nodes): do not refer to files as latents in `PickleStorageTorch` (again) --- .../services/pickle_storage/pickle_storage_torch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index de411bbf47d..16f0d7bb7ad 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -25,17 +25,17 @@ def start(self, invoker: Invoker) -> None: self._delete_all_items() def get(self, name: str) -> T: - latent_path = self._get_path(name) - return torch.load(latent_path) + file_path = self._get_path(name) + return torch.load(file_path) def save(self, name: str, data: T) -> None: self._output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self._get_path(name) - torch.save(data, latent_path) + file_path = self._get_path(name) + torch.save(data, file_path) def delete(self, name: str) -> None: - latent_path = self._get_path(name) - latent_path.unlink() + file_path = self._get_path(name) + file_path.unlink() def _get_path(self, name: str) -> Path: return self._output_folder / name From fe0391c86bd819536c4f37f384792c66a1709883 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:39:03 +1100 Subject: [PATCH 039/100] feat(nodes): use `ItemStorageABC` for tensors and conditioning Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both. There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution. The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items. 
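The name extraction relies on `typing.get_args(self.__orig_class__)`, which is only populated after the instance is constructed. A standalone sketch of the trick (the class name here is illustrative; the real implementation is in the diff below):

    import typing
    import uuid
    from typing import Generic, TypeVar

    T = TypeVar("T")

    class EphemeralDiskSketch(Generic[T]):
        def _item_class_name(self) -> str:
            # __orig_class__ records the parameterized alias used to create the
            # instance (e.g. EphemeralDiskSketch[bytes]); it is set only after
            # __init__ returns, so it must be resolved lazily.
            return typing.get_args(self.__orig_class__)[0].__name__

        def _new_item_id(self) -> str:
            return f"{self._item_class_name()}_{uuid.uuid4()}"

    storage = EphemeralDiskSketch[bytes]()
    print(storage._new_item_id())  # e.g. "bytes_1c9b2f0e-..."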
--- invokeai/app/api/dependencies.py | 10 +-- invokeai/app/services/invocation_services.py | 5 +- .../item_storage/item_storage_base.py | 2 +- .../item_storage_ephemeral_disk.py | 72 +++++++++++++++++++ .../item_storage_forward_cache.py | 61 ++++++++++++++++ .../item_storage/item_storage_memory.py | 3 +- .../pickle_storage/pickle_storage_base.py | 45 ------------ .../pickle_storage_forward_cache.py | 58 --------------- .../pickle_storage/pickle_storage_torch.py | 63 ---------------- .../app/services/shared/invocation_context.py | 30 +------- 10 files changed, 145 insertions(+), 204 deletions(-) create mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py create mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_base.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6bb0915cb6e..d6fd970a22d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ import torch +from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk +from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache -from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) - conditioning = PickleStorageForwardCache( - PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ItemStorageForwardCache( + ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 81885781acb..69599d83a4b 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -29,7 +29,6 @@ from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase - from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -66,8 +65,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", 
workflow_records: "WorkflowRecordsStorageBase", - tensors: "PickleStorageBase[torch.Tensor]", - conditioning: "PickleStorageBase[ConditioningFieldData]", + tensors: "ItemStorageABC[torch.Tensor]", + conditioning: "ItemStorageABC[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index d7366791594..f2d62ea45fb 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -26,7 +26,7 @@ def get(self, item_id: str) -> T: pass @abstractmethod - def set(self, item: T) -> None: + def set(self, item: T) -> str: """ Sets the item. The id will be extracted based on id_field. :param item: the item to set diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py new file mode 100644 index 00000000000..9843d1e54bc --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -0,0 +1,72 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ItemStorageEphemeralDisk(ItemStorageABC[T]): + """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" + + def __init__(self, output_folder: Path): + super().__init__() + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self.__item_class_name: Optional[str] = None + + @property + def _item_class_name(self) -> str: + if not self.__item_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__item_class_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, item_id: str) -> T: + file_path = self._get_path(item_id) + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + + def set(self, item: T) -> str: + self._output_folder.mkdir(parents=True, exist_ok=True) + item_id = f"{self._item_class_name}_{uuid_string()}" + file_path = self._get_path(item_id) + torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + return item_id + + def delete(self, item_id: str) -> None: + file_path = self._get_path(item_id) + file_path.unlink() + + def _get_path(self, item_id: str) -> Path: + return self._output_folder / item_id + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. 
Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py new file mode 100644 index 00000000000..d1fe8e13fa9 --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC + +T = TypeVar("T") + + +class ItemStorageForwardCache(ItemStorageABC[T]): + """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, item_id: str) -> T: + cache_item = self._get_cache(item_id) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(item_id) + self._set_cache(item_id, latent) + return latent + + def set(self, item: T) -> str: + item_id = self._underlying_storage.set(item) + self._set_cache(item_id, item) + self._on_changed(item) + return item_id + + def delete(self, item_id: str) -> None: + self._underlying_storage.delete(item_id) + if item_id in self._cache: + del self._cache[item_id] + self._on_deleted(item_id) + + def _get_cache(self, item_id: str) -> Optional[T]: + return None if item_id not in self._cache else self._cache[item_id] + + def _set_cache(self, item_id: str, data: T): + if item_id not in self._cache: + self._cache[item_id] = data + self._cache_ids.put(item_id) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index d8dd0e06645..6d028745164 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ def get(self, item_id: str) -> T: self._items[item_id] = item return item - def set(self, item: T) -> None: + def set(self, item: T) -> str: item_id = getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,6 +44,7 @@ def set(self, item: T) -> None: self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) + return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. 
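To make the new `set` contract concrete, a toy storage honouring it (`ToyStorage` is illustrative only, not part of this changeset):

    class ToyStorage:
        def __init__(self) -> None:
            self._items: dict[str, str] = {}
            self._count = 0

        def set(self, item: str) -> str:
            self._count += 1
            item_id = f"item_{self._count}"
            self._items[item_id] = item
            return item_id  # the service, not the caller, names the item

        def get(self, item_id: str) -> str:
            return self._items[item_id]

    storage = ToyStorage()
    item_id = storage.set("some payload")  # caller receives the generated ID
    assert storage.get(item_id) == "some payload"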
diff --git a/invokeai/app/services/pickle_storage/pickle_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py deleted file mode 100644 index 558b97c0f1b..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_base.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Callable, Generic, TypeVar - -T = TypeVar("T") - - -class PickleStorageBase(ABC, Generic[T]): - """Responsible for storing and retrieving non-serializable data using a pickler.""" - - _on_changed_callbacks: list[Callable[[T], None]] - _on_deleted_callbacks: list[Callable[[str], None]] - - def __init__(self) -> None: - self._on_changed_callbacks = [] - self._on_deleted_callbacks = [] - - @abstractmethod - def get(self, name: str) -> T: - pass - - @abstractmethod - def save(self, name: str, data: T) -> None: - pass - - @abstractmethod - def delete(self, name: str) -> None: - pass - - def on_changed(self, on_changed: Callable[[T], None]) -> None: - """Register a callback for when an item is changed""" - self._on_changed_callbacks.append(on_changed) - - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: - """Register a callback for when an item is deleted""" - self._on_deleted_callbacks.append(on_deleted) - - def _on_changed(self, item: T) -> None: - for callback in self._on_changed_callbacks: - callback(item) - - def _on_deleted(self, item_id: str) -> None: - for callback in self._on_deleted_callbacks: - callback(item_id) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py deleted file mode 100644 index 3002d9e045d..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageForwardCache(PickleStorageBase[T]): - def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> T: - cache_item = self._get_cache(name) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(name) - self._set_cache(name, latent) - return latent - - def save(self, name: str, data: T) -> None: - self._underlying_storage.save(name, data) - self._set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] - self._on_deleted(name) - - def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] - - def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - 
if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py deleted file mode 100644 index 16f0d7bb7ad..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageTorch(PickleStorageBase[T]): - """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" - - def __init__(self, output_folder: Path, item_type_name: "str"): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._item_type_name = item_type_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, name: str) -> T: - file_path = self._get_path(name) - return torch.load(file_path) - - def save(self, name: str, data: T) -> None: - self._output_folder.mkdir(parents=True, exist_ok=True) - file_path = self._get_path(name) - torch.save(data, file_path) - - def delete(self, name: str) -> None: - file_path = self._get_path(name) - file_path.unlink() - - def _get_path(self, name: str) -> Path: - return self._output_folder / name - - def _delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6756b1f5c6c..baff47a3df4 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -12,7 +12,6 @@ from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.util.misc import uuid_string from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -224,26 +223,7 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save tensors, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all tensors. 
- # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their tensors. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the tensors file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.tensors.save( - name=name, - data=tensor, - ) + name = self._services.tensors.set(item=tensor) return name def get(self, tensor_name: str) -> Tensor: @@ -263,13 +243,7 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - # See comment in TensorsInterface.save for why we generate the name here. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.conditioning.save( - name=name, - data=conditioning_data, - ) + name = self._services.conditioning.set(item=conditioning_data) return name def get(self, conditioning_name: str) -> ConditioningFieldData: From 7fe5283e7400b3f181a200367f3163ee9d8db637 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:50:30 +1100 Subject: [PATCH 040/100] feat(nodes): create helper function to generate the item ID --- .../app/services/item_storage/item_storage_ephemeral_disk.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 9843d1e54bc..377c9c39b3e 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -37,7 +37,7 @@ def get(self, item_id: str) -> T: def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) - item_id = f"{self._item_class_name}_{uuid_string()}" + item_id = self._new_item_id() file_path = self._get_path(item_id) torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] return item_id @@ -49,6 +49,9 @@ def delete(self, item_id: str) -> None: def _get_path(self, item_id: str) -> Path: return self._output_folder / item_id + def _new_item_id(self) -> str: + return f"{self._item_class_name}_{uuid_string()}" + def _delete_all_items(self) -> None: """ Deletes all pickled items from disk. 
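For example, a hedged sketch of swapping in plain pickle for `torch.save`/`torch.load` (the construction at the end is hypothetical; only the `save`/`load` parameters and their `SaveFunc`/`LoadFunc` signatures come from this patch):

    import pickle
    from pathlib import Path

    def pickle_save(item: object, path: Path) -> None:
        # Matches SaveFunc: Callable[[T, Path], None]
        with path.open("wb") as f:
            pickle.dump(item, f)

    def pickle_load(path: Path) -> object:
        # Matches LoadFunc: Callable[[Path], T]
        with path.open("rb") as f:
            return pickle.load(f)

    # Hypothetical usage; the defaults remain torch.save / torch.load:
    # storage = ItemStorageEphemeralDisk[MyObject](
    #     output_folder=Path("outputs/objects"),
    #     save=pickle_save,
    #     load=pickle_load,
    # )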
From 54a67459bf1b6d5ea1a21d2f8cdb8533ed05a7c6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:51:04 +1100 Subject: [PATCH 041/100] feat(nodes): support custom save and load functions in `ItemStorageEphemeralDisk` --- .../item_storage/item_storage_common.py | 10 +++++++ .../item_storage_ephemeral_disk.py | 26 +++++++++++++++---- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_common.py b/invokeai/app/services/item_storage/item_storage_common.py index 8fd677c71b7..7f9bd7bd4ef 100644 --- a/invokeai/app/services/item_storage/item_storage_common.py +++ b/invokeai/app/services/item_storage/item_storage_common.py @@ -1,5 +1,15 @@ +from pathlib import Path +from typing import Callable, TypeAlias, TypeVar + + class ItemNotFoundError(KeyError): """Raised when an item is not found in storage""" def __init__(self, item_id: str) -> None: super().__init__(f"Item with id {item_id} not found") + + +T = TypeVar("T") + +SaveFunc: TypeAlias = Callable[[T, Path], None] +LoadFunc: TypeAlias = Callable[[Path], T] diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 377c9c39b3e..4dc67129dac 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -6,18 +6,31 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") class ItemStorageEphemeralDisk(ItemStorageABC[T]): - """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" - - def __init__(self, output_folder: Path): + """Provides a disk-backed ephemeral storage. The storage is cleared at startup. + + :param output_folder: The folder where the items will be stored + :param save: The function to use to save the items to disk [torch.save] + :param load: The function to use to load the items from disk [torch.load] + """ + + def __init__( + self, + output_folder: Path, + save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] + load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) + self._save = save + self._load = load self.__item_class_name: Optional[str] = None @property @@ -33,13 +46,13 @@ def start(self, invoker: Invoker) -> None: def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + return self._load(file_path) def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) item_id = self._new_item_id() file_path = self._get_path(item_id) - torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + self._save(item, file_path) return item_id def delete(self, item_id: str) -> None: @@ -58,6 +71,9 @@ def _delete_all_items(self) -> None: Must be called after we have access to `self._invoker` (e.g. in `start()`). """ + # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have + # to manually clear them on startup anyways. 
This is a bit simpler and more reliable. + if not self._invoker: raise ValueError("Invoker is not set. Must call `start()` first.") From 8b6e32269702a837d4f785cf56ae545f69df43ef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:54:52 +1100 Subject: [PATCH 042/100] feat(nodes): support custom exception in ephemeral disk storage --- .../item_storage/item_storage_ephemeral_disk.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 4dc67129dac..97c767c87d7 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -1,12 +1,12 @@ import typing from pathlib import Path -from typing import Optional, TypeVar +from typing import Optional, Type, TypeVar import torch from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc +from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") @@ -18,6 +18,7 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]): :param output_folder: The folder where the items will be stored :param save: The function to use to save the items to disk [torch.save] :param load: The function to use to load the items from disk [torch.load] + :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] """ def __init__( @@ -25,12 +26,14 @@ def __init__( output_folder: Path, save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + load_exc: Type[Exception] = FileNotFoundError, ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._save = save self._load = load + self._load_exc = load_exc self.__item_class_name: Optional[str] = None @property @@ -46,7 +49,10 @@ def start(self, invoker: Invoker) -> None: def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return self._load(file_path) + try: + return self._load(file_path) + except self._load_exc as e: + raise ItemNotFoundError(item_id) from e def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) From 06429028c8015d050fe3cb87951c6362e8702ac4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:30:46 +1100 Subject: [PATCH 043/100] revert(nodes): revert making tensors/conditioning use item storage Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning. Refined the class a bit too. 
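To make the split concrete, the two services end up with differently shaped APIs. A rough sketch distilled from the diffs below (signatures only, not the full classes):

    from typing import Generic, TypeVar

    T = TypeVar("T")

    # Keyed records: the caller (or the item itself) supplies a meaningful ID.
    class ItemStorageABC(Generic[T]):
        def get(self, item_id: str) -> T: ...
        def set(self, item: T) -> None: ...

    # Opaque blobs: the storage invents a name and returns it on save.
    class ObjectSerializerBase(Generic[T]):
        def load(self, name: str) -> T: ...
        def save(self, obj: T) -> str: ...

`get`/`set` suits records addressed by stable IDs, while `load`/`save` suits serialized objects whose names only the storage knows.
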
--- invokeai/app/api/dependencies.py | 10 +- invokeai/app/invocations/latent.py | 24 ++--- invokeai/app/invocations/primitives.py | 2 +- invokeai/app/services/invocation_services.py | 6 +- .../item_storage/item_storage_base.py | 8 +- .../item_storage/item_storage_common.py | 10 -- .../item_storage_ephemeral_disk.py | 97 ------------------- .../item_storage_forward_cache.py | 61 ------------ .../item_storage/item_storage_memory.py | 3 +- .../object_serializer_base.py | 53 ++++++++++ .../object_serializer_common.py | 5 + .../object_serializer_ephemeral_disk.py | 84 ++++++++++++++++ .../object_serializer_forward_cache.py | 61 ++++++++++++ .../app/services/shared/invocation_context.py | 24 ++--- 14 files changed, 243 insertions(+), 205 deletions(-) delete mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py delete mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_base.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_common.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_forward_cache.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index d6fd970a22d..0c80494616f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ import torch -from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk -from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) - conditioning = ItemStorageForwardCache( - ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ObjectSerializerForwardCache( + ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 94440d3e2aa..4137ab6e2f6 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -304,11 +304,11 @@ def get_conditioning_data( unet, seed, ) -> ConditioningData: - positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) + positive_cond_data = 
context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -621,10 +621,10 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None - mask = context.tensors.get(self.denoise_mask.mask_name) + mask = context.tensors.load(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.tensors.get(self.noise.latents_name) + noise = context.tensors.load(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.tensors.get(self.latents_a.latents_name) - latents_b = context.tensors.get(self.latents_b.latents_name) + latents_a = context.tensors.load(self.latents_a.latents_name) + latents_b = context.tensors.load(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR diff 
--git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 082d5432ccf..d0f95c92d02 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -344,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 69599d83a4b..e893be87636 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + if TYPE_CHECKING: from logging import Logger @@ -65,8 +67,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", - tensors: "ItemStorageABC[torch.Tensor]", - conditioning: "ItemStorageABC[ConditioningFieldData]", + tensors: "ObjectSerializerBase[torch.Tensor]", + conditioning: "ObjectSerializerBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index f2d62ea45fb..ef227ba241c 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod from typing import Callable, Generic, TypeVar -T = TypeVar("T") +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) class ItemStorageABC(ABC, Generic[T]): @@ -26,9 +28,9 @@ def get(self, item_id: str) -> T: pass @abstractmethod - def set(self, item: T) -> str: + def set(self, item: T) -> None: """ - Sets the item. The id will be extracted based on id_field. + Sets the item. 
:param item: the item to set """ pass diff --git a/invokeai/app/services/item_storage/item_storage_common.py b/invokeai/app/services/item_storage/item_storage_common.py index 7f9bd7bd4ef..8fd677c71b7 100644 --- a/invokeai/app/services/item_storage/item_storage_common.py +++ b/invokeai/app/services/item_storage/item_storage_common.py @@ -1,15 +1,5 @@ -from pathlib import Path -from typing import Callable, TypeAlias, TypeVar - - class ItemNotFoundError(KeyError): """Raised when an item is not found in storage""" def __init__(self, item_id: str) -> None: super().__init__(f"Item with id {item_id} not found") - - -T = TypeVar("T") - -SaveFunc: TypeAlias = Callable[[T, Path], None] -LoadFunc: TypeAlias = Callable[[Path], T] diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py deleted file mode 100644 index 97c767c87d7..00000000000 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ /dev/null @@ -1,97 +0,0 @@ -import typing -from pathlib import Path -from typing import Optional, Type, TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc -from invokeai.app.util.misc import uuid_string - -T = TypeVar("T") - - -class ItemStorageEphemeralDisk(ItemStorageABC[T]): - """Provides a disk-backed ephemeral storage. The storage is cleared at startup. - - :param output_folder: The folder where the items will be stored - :param save: The function to use to save the items to disk [torch.save] - :param load: The function to use to load the items from disk [torch.load] - :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] - """ - - def __init__( - self, - output_folder: Path, - save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] - load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] - load_exc: Type[Exception] = FileNotFoundError, - ): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._save = save - self._load = load - self._load_exc = load_exc - self.__item_class_name: Optional[str] = None - - @property - def _item_class_name(self) -> str: - if not self.__item_class_name: - # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason - self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] - return self.__item_class_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, item_id: str) -> T: - file_path = self._get_path(item_id) - try: - return self._load(file_path) - except self._load_exc as e: - raise ItemNotFoundError(item_id) from e - - def set(self, item: T) -> str: - self._output_folder.mkdir(parents=True, exist_ok=True) - item_id = self._new_item_id() - file_path = self._get_path(item_id) - self._save(item, file_path) - return item_id - - def delete(self, item_id: str) -> None: - file_path = self._get_path(item_id) - file_path.unlink() - - def _get_path(self, item_id: str) -> Path: - return self._output_folder / item_id - - def _new_item_id(self) -> str: - return f"{self._item_class_name}_{uuid_string()}" - - def 
_delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py deleted file mode 100644 index d1fe8e13fa9..00000000000 --- a/invokeai/app/services/item_storage/item_storage_forward_cache.py +++ /dev/null @@ -1,61 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC - -T = TypeVar("T") - - -class ItemStorageForwardCache(ItemStorageABC[T]): - """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" - - def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, item_id: str) -> T: - cache_item = self._get_cache(item_id) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(item_id) - self._set_cache(item_id, latent) - return latent - - def set(self, item: T) -> str: - item_id = self._underlying_storage.set(item) - self._set_cache(item_id, item) - self._on_changed(item) - return item_id - - def delete(self, item_id: str) -> None: - self._underlying_storage.delete(item_id) - if item_id in self._cache: - del self._cache[item_id] - self._on_deleted(item_id) - - def _get_cache(self, item_id: str) -> Optional[T]: - return None if item_id not in self._cache else self._cache[item_id] - - def _set_cache(self, item_id: str, data: T): - if item_id not in self._cache: - self._cache[item_id] = data - self._cache_ids.put(item_id) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index 6d028745164..d8dd0e06645 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ def get(self, item_id: str) -> T: self._items[item_id] = item return item - def set(self, item: T) -> str: + def set(self, item: T) -> None: item_id = 
getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,7 +44,6 @@ def set(self, item: T) -> str: self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) - return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py new file mode 100644 index 00000000000..b01a641d8fb --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Callable, Generic, TypeVar + +T = TypeVar("T") + + +class ObjectSerializerBase(ABC, Generic[T]): + """Saves and loads arbitrary python objects.""" + + def __init__(self) -> None: + self._on_saved_callbacks: list[Callable[[str, T], None]] = [] + self._on_deleted_callbacks: list[Callable[[str], None]] = [] + + @abstractmethod + def load(self, name: str) -> T: + """ + Loads the object. + :param name: The name of the object to load. + :raises ObjectNotFoundError: if the object is not found + """ + pass + + @abstractmethod + def save(self, obj: T) -> str: + """ + Saves the object, returning its name. + :param obj: The object to save. + """ + pass + + @abstractmethod + def delete(self, name: str) -> None: + """ + Deletes the object, if it exists. + :param name: The name of the object to delete. + """ + pass + + def on_saved(self, on_saved: Callable[[str, T], None]) -> None: + """Register a callback for when an object is saved""" + self._on_saved_callbacks.append(on_saved) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an object is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_saved(self, name: str, obj: T) -> None: + for callback in self._on_saved_callbacks: + callback(name, obj) + + def _on_deleted(self, name: str) -> None: + for callback in self._on_deleted_callbacks: + callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_common.py b/invokeai/app/services/object_serializer/object_serializer_common.py new file mode 100644 index 00000000000..7057386541f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_common.py @@ -0,0 +1,5 @@ +class ObjectNotFoundError(KeyError): + """Raised when an object is not found while loading""" + + def __init__(self, name: str) -> None: + super().__init__(f"Object with name {name} not found") diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py new file mode 100644 index 00000000000..afa868b157f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -0,0 +1,84 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): + """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. 
+ + :param output_folder: The folder where the objects will be stored + """ + + def __init__(self, output_dir: Path): + super().__init__() + self._output_dir = output_dir + self._output_dir.mkdir(parents=True, exist_ok=True) + self.__obj_class_name: Optional[str] = None + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all() + + def load(self, name: str) -> T: + file_path = self._get_path(name) + try: + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + except FileNotFoundError as e: + raise ObjectNotFoundError(name) from e + + def save(self, obj: T) -> str: + name = self._new_name() + file_path = self._get_path(name) + torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType] + return name + + def delete(self, name: str) -> None: + file_path = self._get_path(name) + file_path.unlink() + + @property + def _obj_class_name(self) -> str: + if not self.__obj_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__obj_class_name + + def _get_path(self, name: str) -> Path: + return self._output_dir / name + + def _new_name(self) -> str: + return f"{self._obj_class_name}_{uuid_string()}" + + def _delete_all(self) -> None: + """ + Deletes all objects from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have + # to manually clear them on startup anyways. This is a bit simpler and more reliable. + + if not self._invoker: + raise ValueError("Invoker is not set. Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_dir).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py new file mode 100644 index 00000000000..40e34e65406 --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + +T = TypeVar("T") + + +class ObjectSerializerForwardCache(ObjectSerializerBase[T]): + """Provides a simple forward cache for an underlying storage. 
The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def load(self, name: str) -> T: + cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.load(name) + self._set_cache(name, latent) + return latent + + def save(self, obj: T) -> str: + name = self._underlying_storage.save(obj) + self._set_cache(name, obj) + self._on_saved(name, obj) + return name + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index baff47a3df4..8c5a821fd0f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -223,16 +223,16 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - name = self._services.tensors.set(item=tensor) - return name + tensor_id = self._services.tensors.save(obj=tensor) + return tensor_id - def get(self, tensor_name: str) -> Tensor: + def load(self, name: str) -> Tensor: """ - Gets a tensor by name. + Loads a tensor by name. - :param tensor_name: The name of the tensor to get. + :param name: The name of the tensor to load. """ - return self._services.tensors.get(tensor_name) + return self._services.tensors.load(name) class ConditioningInterface(InvocationContextInterface): @@ -243,17 +243,17 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - name = self._services.conditioning.set(item=conditioning_data) - return name + conditioning_id = self._services.conditioning.save(obj=conditioning_data) + return conditioning_id - def get(self, conditioning_name: str) -> ConditioningFieldData: + def load(self, name: str) -> ConditioningFieldData: """ - Gets conditioning data by name. + Loads conditioning data by name. - :param conditioning_name: The name of the conditioning data to get. + :param name: The name of the conditioning data to load. """ - return self._services.conditioning.get(conditioning_name) + return self._services.conditioning.load(name) class ModelsInterface(InvocationContextInterface): From 55fa7855610bc1626c9a3ec7a7cbb56833bdcd4e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:43:22 +1100 Subject: [PATCH 044/100] tidy(nodes): remove object serializer on_saved It's unused. 
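For context, the surviving `on_deleted` hook is registered as below (a minimal usage sketch; the tensor service wiring and temp path are illustrative, but the classes and signatures match this series):

    from pathlib import Path

    import torch

    from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
    from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache

    tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](Path("/tmp/tensors")))

    def log_deletion(name: str) -> None:
        print(f"tensor deleted: {name}")

    tensors.on_deleted(log_deletion)  # fired via _on_deleted() once delete() succeeds
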
--- .../services/object_serializer/object_serializer_base.py | 9 --------- .../object_serializer/object_serializer_forward_cache.py | 1 - 2 files changed, 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py index b01a641d8fb..ff19b4a039d 100644 --- a/invokeai/app/services/object_serializer/object_serializer_base.py +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -8,7 +8,6 @@ class ObjectSerializerBase(ABC, Generic[T]): """Saves and loads arbitrary python objects.""" def __init__(self) -> None: - self._on_saved_callbacks: list[Callable[[str, T], None]] = [] self._on_deleted_callbacks: list[Callable[[str], None]] = [] @abstractmethod def load(self, name: str) -> T: @@ -36,18 +35,10 @@ def delete(self, name: str) -> None: """ pass - def on_saved(self, on_saved: Callable[[str, T], None]) -> None: - """Register a callback for when an object is saved""" - self._on_saved_callbacks.append(on_saved) - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: """Register a callback for when an object is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_saved(self, name: str, obj: T) -> None: - for callback in self._on_saved_callbacks: - callback(name, obj) - def _on_deleted(self, name: str) -> None: for callback in self._on_deleted_callbacks: callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 40e34e65406..2a4ecdd844b 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -41,7 +41,6 @@ def load(self, name: str) -> T: def save(self, obj: T) -> str: name = self._underlying_storage.save(obj) self._set_cache(name, obj) - self._on_saved(name, obj) return name def delete(self, name: str) -> None: From 83ddcc5f3a9f74ee7fc680346660ecb9921c43cf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:18:58 +1100 Subject: [PATCH 045/100] feat(nodes): allow `_delete_all` in obj serializer to be called at any time `_delete_all` logged how many items it deleted, and had to be called _after_ service start because it needed access to the logger. Move the logger call to the startup method and return the deleted stats from `_delete_all`. This lets `_delete_all` be called at any time. --- .../object_serializer_ephemeral_disk.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index afa868b157f..9545d1714d7 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,4 +1,5 @@ import typing +from dataclasses import dataclass from pathlib import Path from typing import Optional, TypeVar @@ -12,6 +13,12 @@ T = TypeVar("T") +@dataclass +class DeleteAllResult: + deleted_count: int + freed_space_bytes: float + + class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup.
@@ -26,7 +33,12 @@ def __init__(self, output_dir: Path): def start(self, invoker: Invoker) -> None: self._invoker = invoker - self._delete_all() + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) @@ -58,18 +70,14 @@ def _get_path(self, name: str) -> Path: def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> None: + def _delete_all(self) -> DeleteAllResult: """ Deletes all objects from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). """ # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have # to manually clear them on startup anyways. This is a bit simpler and more reliable. - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - deleted_count = 0 freed_space = 0 for file in Path(self._output_dir).glob("*"): @@ -77,8 +85,4 @@ def _delete_all(self) -> None: freed_space += file.stat().st_size deleted_count += 1 file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + return DeleteAllResult(deleted_count, freed_space) From 7a2b606001e4cc96436651e580440e8e4cda9d76 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:20:10 +1100 Subject: [PATCH 046/100] tests: add object serializer tests These test both object serializer and its forward cache implementation. 
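The new tests depend only on pytest's built-in `tmp_path` fixture, so they run without any app services; a standard file-path invocation works:

    pytest tests/test_object_serializer_ephemeral_disk.py
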
--- .../test_object_serializer_ephemeral_disk.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/test_object_serializer_ephemeral_disk.py diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_ephemeral_disk.py new file mode 100644 index 00000000000..fffa65304f6 --- /dev/null +++ b/tests/test_object_serializer_ephemeral_disk.py @@ -0,0 +1,148 @@ +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch + +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache + + +@dataclass +class MockDataclass: + foo: str + + +@pytest.fixture +def obj_serializer(tmp_path: Path): + return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + +@pytest.fixture +def fwd_cache(tmp_path: Path): + return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + + +def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + assert obj_serializer._output_dir == tmp_path + + +def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert Path(obj_serializer._output_dir, obj_1_name).exists() + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert obj_serializer.load(obj_1_name).foo == "bar" + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert obj_serializer.load(obj_2_name).foo == "baz" + + with pytest.raises(ObjectNotFoundError): + obj_serializer.load("nonexistent_object_name") + + +def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + obj_serializer.delete(obj_1_name) + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + delete_all_result = obj_serializer._delete_all() + + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert not Path(obj_serializer._output_dir, obj_2_name).exists() + assert delete_all_result.deleted_count == 2 + + +def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + obj_1_loaded = obj_serializer.load(obj_1_name) + assert isinstance(obj_1_loaded, MockDataclass) + assert obj_1_loaded.foo == "bar" + 
assert obj_1_name.startswith("MockDataclass_") + + obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_2_name = obj_serializer.save(9001) + assert obj_serializer.load(obj_2_name) == 9001 + assert obj_2_name.startswith("int_") + + obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_3_name = obj_serializer.save("foo") + assert obj_serializer.load(obj_3_name) == "foo" + assert obj_3_name.startswith("str_") + + obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer.load(obj_4_name) + assert isinstance(obj_4_loaded, torch.Tensor) + assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) + assert obj_4_name.startswith("Tensor_") + + +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + fwd_cache = ObjectSerializerForwardCache(obj_serializer) + assert fwd_cache._underlying_storage == obj_serializer + + +def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj = MockDataclass(foo="bar") + obj_name = fwd_cache.save(obj) + obj_loaded = fwd_cache.load(obj_name) + obj_underlying = fwd_cache._underlying_storage.load(obj_name) + assert obj_loaded == obj_underlying + assert obj_loaded.foo == "bar" + + +def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = fwd_cache.save(obj_1) + obj_2 = MockDataclass(foo="baz") + obj_2_name = fwd_cache.save(obj_2) + obj_3 = MockDataclass(foo="qux") + obj_3_name = fwd_cache.save(obj_3) + assert obj_1_name not in fwd_cache._cache + assert obj_2_name in fwd_cache._cache + assert obj_3_name in fwd_cache._cache + # apparently qsize is "not reliable"? + assert fwd_cache._cache_ids.qsize() == 2 + + +def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + called_name = None + obj_1 = MockDataclass(foo="bar") + + def on_deleted(name: str): + nonlocal called_name + called_name = name + + fwd_cache.on_deleted(on_deleted) + obj_1_name = fwd_cache.save(obj_1) + fwd_cache.delete(obj_1_name) + assert called_name == obj_1_name From 1f27ddc07d8bdd2433857b898f931fbd84e0c153 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:23:47 +1100 Subject: [PATCH 047/100] tidy(nodes): minor spelling correction --- invokeai/app/services/shared/invocation_context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 8c5a821fd0f..828d3d84904 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -223,8 +223,8 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - tensor_id = self._services.tensors.save(obj=tensor) - return tensor_id + name = self._services.tensors.save(obj=tensor) + return name def load(self, name: str) -> Tensor: """ @@ -243,8 +243,8 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. 
""" - conditioning_id = self._services.conditioning.save(obj=conditioning_data) - return conditioning_id + name = self._services.conditioning.save(obj=conditioning_data) + return name def load(self, name: str) -> ConditioningFieldData: """ From bcb85e100db45afea6441fe4ca617bcf4372572a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:36:53 +1100 Subject: [PATCH 048/100] tests: fix broken tests --- .../object_serializer_ephemeral_disk.py | 9 ++++++--- .../object_serializer_forward_cache.py | 10 ++++++---- tests/aa_nodes/test_graph_execution_state.py | 5 +++-- tests/aa_nodes/test_invoker.py | 3 ++- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index 9545d1714d7..880848a1425 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,15 +1,18 @@ import typing from dataclasses import dataclass from pathlib import Path -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.util.misc import uuid_string +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + + T = TypeVar("T") @@ -31,7 +34,7 @@ def __init__(self, output_dir: Path): self._output_dir.mkdir(parents=True, exist_ok=True) self.__obj_class_name: Optional[str] = None - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker delete_all_result = self._delete_all() if delete_all_result.deleted_count > 0: diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 2a4ecdd844b..c8ca13982c1 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,11 +1,13 @@ from queue import Queue -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase T = TypeVar("T") +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + class ObjectSerializerForwardCache(ObjectSerializerBase[T]): """Provides a simple forward cache for an underlying storage. 
The cache is LRU and has a maximum size.""" @@ -17,13 +19,13 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker start_op = getattr(self._underlying_storage, "start", None) if callable(start_op): start_op(invoker) - def stop(self, invoker: Invoker) -> None: + def stop(self, invoker: "Invoker") -> None: self._invoker = invoker stop_op = getattr(self._underlying_storage, "stop", None) if callable(stop_op): diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index aba7c5694f3..27d2d2230a3 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -60,7 +60,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -74,6 +73,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) @@ -89,7 +90,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B config=None, context_data=None, images=None, - latents=None, + tensors=None, logger=None, models=None, util=None, diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a0..437ea0f00d3 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -63,7 +63,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -77,6 +76,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) From bc5f356390c70372cc226ba097c4831c0ae77993 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 07:55:36 +1100 Subject: [PATCH 049/100] feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders --- invokeai/app/invocations/noise.py | 5 +++-- invokeai/app/invocations/primitives.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 74b3d6e4cb1..4093030388b 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,6 +5,7 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -70,8 +71,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + 
width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d0f95c92d02..2a9cb8cf9bf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -16,6 +16,7 @@ OutputField, UIComponent, ) +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -321,8 +322,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) From 8892df1d97919c442b21ecbb28961320855ca4bd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 08:14:04 +1100 Subject: [PATCH 050/100] Revert "feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders" This reverts commit ef18fc546560277302f3886e456da9a47e8edce0. --- invokeai/app/invocations/noise.py | 5 ++--- invokeai/app/invocations/primitives.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 4093030388b..74b3d6e4cb1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,7 +5,6 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -71,8 +70,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * LATENT_SCALE_FACTOR, - height=latents.size()[2] * LATENT_SCALE_FACTOR, + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 2a9cb8cf9bf..d0f95c92d02 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -16,7 +16,6 @@ OutputField, UIComponent, ) -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -322,8 +321,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * LATENT_SCALE_FACTOR, - height=latents.size()[2] * LATENT_SCALE_FACTOR, + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, ) From cb0b389b4b3fbdc45f835c6275c8c68ac963f9ff Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:57:01 +1100 Subject: [PATCH 051/100] tidy(nodes): clarify comment --- 
invokeai/app/services/shared/invocation_context.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 828d3d84904..3d06cf92725 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -279,15 +279,8 @@ def load( :param submodel: The submodel of the model to get. """ - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. + # The model manager emits events as it loads the model. It needs the context data to build + # the event payloads. return self._services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=self._context_data From ed772a71079903315b378a8d82f941effc9f7975 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:05:33 +1100 Subject: [PATCH 052/100] fix(nodes): use `metadata`/`board_id` if provided by user, overriding `WithMetadata`/`WithBoard`-provided values --- .../app/services/shared/invocation_context.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 3d06cf92725..1ca44b78625 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -167,16 +167,19 @@ def save( **Use this only if you want to override or provide metadata manually!** """ - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - self._context_data.invocation.metadata - if isinstance(self._context_data.invocation, WithMetadata) - else metadata - ) - - # If the invocation inherits WithBoard, use that. Else, use the board_id passed in. - board_ = self._context_data.invocation.board if isinstance(self._context_data.invocation, WithBoard) else None - board_id_ = board_.board_id if board_ is not None else board_id + # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. + metadata_ = None + if metadata: + metadata_ = metadata + elif isinstance(self._context_data.invocation, WithMetadata): + metadata_ = self._context_data.invocation.metadata + + # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. + board_id_ = None + if board_id: + board_id_ = board_id + elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board: + board_id_ = self._context_data.invocation.board.board_id return self._services.images.create( image=image, From c58f8c32694f6eef3322922fac94bd0f8809a87c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:09:59 +1100 Subject: [PATCH 053/100] feat(nodes): make delete on startup configurable for obj serializer - The default is to not delete on startup - feels safer. 
- The two services using this class _do_ delete on startup. - The class has "ephemeral" removed from its name. - Tests & app updated for this change. --- invokeai/app/api/dependencies.py | 8 ++- ...eral_disk.py => object_serializer_disk.py} | 22 ++++--- ...disk.py => test_object_serializer_disk.py} | 64 ++++++++++++++----- 3 files changed, 67 insertions(+), 27 deletions(-) rename invokeai/app/services/object_serializer/{object_serializer_ephemeral_disk.py => object_serializer_disk.py} (77%) rename tests/{test_object_serializer_ephemeral_disk.py => test_object_serializer_disk.py} (65%) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0c80494616f..2acb961aa7a 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -5,7 +5,7 @@ import torch from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore @@ -90,9 +90,11 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + tensors = ObjectSerializerForwardCache( + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py similarity index 77% rename from invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py rename to invokeai/app/services/object_serializer/object_serializer_disk.py index 880848a1425..174ff15192d 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -22,26 +22,30 @@ class DeleteAllResult: freed_space_bytes: float -class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): - """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. +class ObjectSerializerDisk(ObjectSerializerBase[T]): + """Provides a disk-backed storage for arbitrary python objects. 
:param output_folder: The folder where the objects will be stored + :param delete_on_startup: If True, all objects in the output folder will be deleted on startup """ - def __init__(self, output_dir: Path): + def __init__(self, output_dir: Path, delete_on_startup: bool = False): super().__init__() self._output_dir = output_dir self._output_dir.mkdir(parents=True, exist_ok=True) + self._delete_on_startup = delete_on_startup self.__obj_class_name: Optional[str] = None def start(self, invoker: "Invoker") -> None: self._invoker = invoker - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + + if self._delete_on_startup: + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_disk.py similarity index 65% rename from tests/test_object_serializer_ephemeral_disk.py rename to tests/test_object_serializer_disk.py index fffa65304f6..5ce1e579013 100644 --- a/tests/test_object_serializer_ephemeral_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from logging import Logger from pathlib import Path +from unittest.mock import Mock import pytest import torch +from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -14,22 +17,31 @@ class MockDataclass: foo: str +def count_files(path: Path): + return len(list(path.iterdir())) + + @pytest.fixture def obj_serializer(tmp_path: Path): - return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + return ObjectSerializerDisk[MockDataclass](tmp_path) @pytest.fixture def fwd_cache(tmp_path: Path): - return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) + + +@pytest.fixture +def mock_invoker_with_logger(): + return Mock(Invoker, services=Mock(logger=Mock(Logger))) -def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path -def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert Path(obj_serializer._output_dir, 
obj_1_name).exists() @@ -39,7 +51,7 @@ def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEph assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert obj_serializer.load(obj_1_name).foo == "bar" @@ -52,7 +64,7 @@ def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEph obj_serializer.load("nonexistent_object_name") -def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -64,7 +76,7 @@ def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerE assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -78,8 +90,30 @@ def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSeriali assert delete_all_result.deleted_count == 2 -def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) + assert obj_serializer._delete_on_startup is False + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) + assert obj_serializer._delete_on_startup is True + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert not Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -88,17 +122,17 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): assert obj_1_loaded.foo == "bar" assert obj_1_name.startswith("MockDataclass_") - obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_serializer = ObjectSerializerDisk[int](tmp_path) obj_2_name = obj_serializer.save(9001) assert obj_serializer.load(obj_2_name) == 9001 assert obj_2_name.startswith("int_") - obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_serializer = ObjectSerializerDisk[str](tmp_path) obj_3_name = obj_serializer.save("foo") assert obj_serializer.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") - obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path) obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) 
obj_4_loaded = obj_serializer.load(obj_4_name)
     assert isinstance(obj_4_loaded, torch.Tensor)
@@ -106,7 +140,7 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path):
     assert obj_4_name.startswith("Tensor_")
 
-def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]):
+def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
     fwd_cache = ObjectSerializerForwardCache(obj_serializer)
     assert fwd_cache._underlying_storage == obj_serializer
 
From ff249a231522e7a43417a90d48fa6ab9320b7fdd Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 09:41:23 +1100
Subject: [PATCH 054/100] tidy(nodes): do not unnecessarily store invoker

---
 .../app/services/object_serializer/object_serializer_disk.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py
index 174ff15192d..b3827e16a92 100644
--- a/invokeai/app/services/object_serializer/object_serializer_disk.py
+++ b/invokeai/app/services/object_serializer/object_serializer_disk.py
@@ -37,13 +37,11 @@ def __init__(self, output_dir: Path, delete_on_startup: bool = False):
         self.__obj_class_name: Optional[str] = None
 
     def start(self, invoker: "Invoker") -> None:
-        self._invoker = invoker
-
         if self._delete_on_startup:
             delete_all_result = self._delete_all()
             if delete_all_result.deleted_count > 0:
                 freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2)
-                self._invoker.services.logger.info(
+                invoker.services.logger.info(
                     f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)"
                 )
 
From c1e5cd589364913ffd382933bf06bce26e762f49 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 10:06:50 +1100
Subject: [PATCH 055/100] tidy(nodes): "latents" -> "obj"

---
 .../object_serializer/object_serializer_forward_cache.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
index c8ca13982c1..812731f456a 100644
--- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
+++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
@@ -36,9 +36,9 @@ def load(self, name: str) -> T:
         if cache_item is not None:
             return cache_item
 
-        latent = self._underlying_storage.load(name)
-        self._set_cache(name, latent)
-        return latent
+        obj = self._underlying_storage.load(name)
+        self._set_cache(name, obj)
+        return obj
 
     def save(self, obj: T) -> str:
         name = self._underlying_storage.save(obj)
 
From 6bb2dda3f1b6c262d73d5c717d4e84bd45ac29e3 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 10:10:48 +1100
Subject: [PATCH 056/100] chore(nodes): fix pyright ignore

---
 .../app/services/object_serializer/object_serializer_disk.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py
index b3827e16a92..06f86aa460c 100644
--- a/invokeai/app/services/object_serializer/object_serializer_disk.py
+++
b/invokeai/app/services/object_serializer/object_serializer_disk.py
@@ -66,7 +66,7 @@ def delete(self, name: str) -> None:
     def _obj_class_name(self) -> str:
         if not self.__obj_class_name:
             # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
-            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues]
+            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
         return self.__obj_class_name
 
     def _get_path(self, name: str) -> Path:
 
From a7207ed8cf2485842271790d0f266b626439816c Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 10:11:31 +1100
Subject: [PATCH 057/100] chore(nodes): update ObjectSerializerForwardCache docstring

---
 .../object_serializer/object_serializer_forward_cache.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
index 812731f456a..b361259a4b1 100644
--- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
+++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py
@@ -10,7 +10,10 @@
 
 
 class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
-    """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
+    """
+    Provides an LRU cache for an instance of `ObjectSerializerBase`.
+    Saving an object to the cache always writes through to the underlying storage.
+    """
 
     def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
         super().__init__()
 
From 199ddd6623ea738f2d61d869fe88e5f5184cf1ee Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 10 Feb 2024 18:46:51 +1100
Subject: [PATCH 058/100] tests: test ObjectSerializerDisk class name extraction

---
 tests/test_object_serializer_disk.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py
index 5ce1e579013..2bc7e16937e 100644
--- a/tests/test_object_serializer_disk.py
+++ b/tests/test_object_serializer_disk.py
@@ -113,28 +113,31 @@ def test_obj_serializer_disk_different_types(tmp_path: Path):
-    obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
-
+    obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
     obj_1 = MockDataclass(foo="bar")
-    obj_1_name = obj_serializer.save(obj_1)
-    obj_1_loaded = obj_serializer.load(obj_1_name)
+    obj_1_name = obj_serializer_1.save(obj_1)
+    obj_1_loaded = obj_serializer_1.load(obj_1_name)
+    assert obj_serializer_1._obj_class_name == "MockDataclass"
     assert isinstance(obj_1_loaded, MockDataclass)
     assert obj_1_loaded.foo == "bar"
     assert obj_1_name.startswith("MockDataclass_")
 
-    obj_serializer = ObjectSerializerDisk[int](tmp_path)
-    obj_2_name = obj_serializer.save(9001)
-    assert obj_serializer.load(obj_2_name) == 9001
+    obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
+    obj_2_name = obj_serializer_2.save(9001)
+    assert obj_serializer_2._obj_class_name == "int"
+    assert obj_serializer_2.load(obj_2_name) == 9001
     assert obj_2_name.startswith("int_")
 
-
obj_serializer = ObjectSerializerDisk[str](tmp_path) - obj_3_name = obj_serializer.save("foo") - assert obj_serializer.load(obj_3_name) == "foo" + obj_serializer_3 = ObjectSerializerDisk[str](tmp_path) + obj_3_name = obj_serializer_3.save("foo") + assert obj_serializer_3._obj_class_name == "str" + assert obj_serializer_3.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") - obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path) - obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) - obj_4_loaded = obj_serializer.load(obj_4_name) + obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer_4.load(obj_4_name) + assert obj_serializer_4._obj_class_name == "Tensor" assert isinstance(obj_4_loaded, torch.Tensor) assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) assert obj_4_name.startswith("Tensor_") From b7ffd36cc6d81f832ae922fe62d7a029f10115c1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 19:11:28 +1100 Subject: [PATCH 059/100] feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`. The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`. The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped. In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves. This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often. Tests updated. 
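For anyone reviewing who hasn't used this pattern: a minimal standalone sketch of the
`TemporaryDirectory` behaviour relied on here. Paths and names are illustrative only,
not the service API:

    import tempfile
    from pathlib import Path

    base_output_dir = Path("invokeai/outputs/tensors")  # hypothetical base dir
    base_output_dir.mkdir(parents=True, exist_ok=True)

    # The temp dir is created *inside* the base dir, e.g. .../tensors/tmpvj35ht7b/.
    # `ignore_cleanup_errors=True` (Python 3.10+) avoids fatal cleanup errors on Windows.
    tmpdir = tempfile.TemporaryDirectory(dir=base_output_dir, ignore_cleanup_errors=True)
    output_dir = Path(tmpdir.name)

    (output_dir / "example.pt").write_bytes(b"\x00")  # serialized objects live in the temp dir

    # Cleanup runs when `cleanup()` is called (our `stop()`), or when the TemporaryDirectory
    # object is garbage-collected - but not if the process is killed.
    tmpdir.cleanup()
    assert not output_dir.exists()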
--- invokeai/app/api/dependencies.py | 4 +- .../object_serializer_disk.py | 56 ++++++++----------- tests/test_object_serializer_disk.py | 53 +++++++----------- 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 2acb961aa7a..0f2a92b5c8e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -91,10 +91,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) tensors = ObjectSerializerForwardCache( - ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True) ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 06f86aa460c..935fec30605 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -1,3 +1,4 @@ +import tempfile import typing from dataclasses import dataclass from pathlib import Path @@ -23,28 +24,24 @@ class DeleteAllResult: class ObjectSerializerDisk(ObjectSerializerBase[T]): - """Provides a disk-backed storage for arbitrary python objects. + """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. 
- :param output_folder: The folder where the objects will be stored - :param delete_on_startup: If True, all objects in the output folder will be deleted on startup + :param output_dir: The folder where the serialized objects will be stored + :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit """ - def __init__(self, output_dir: Path, delete_on_startup: bool = False): + def __init__(self, output_dir: Path, ephemeral: bool = False): super().__init__() - self._output_dir = output_dir - self._output_dir.mkdir(parents=True, exist_ok=True) - self._delete_on_startup = delete_on_startup + self._ephemeral = ephemeral + self._base_output_dir = output_dir + self._base_output_dir.mkdir(parents=True, exist_ok=True) + # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows + self._tempdir = ( + tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None + ) + self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir self.__obj_class_name: Optional[str] = None - def start(self, invoker: "Invoker") -> None: - if self._delete_on_startup: - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) - def load(self, name: str) -> T: file_path = self._get_path(name) try: @@ -75,19 +72,14 @@ def _get_path(self, name: str) -> Path: def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> DeleteAllResult: - """ - Deletes all objects from disk. - """ - - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_dir).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - return DeleteAllResult(deleted_count, freed_space) + def _tempdir_cleanup(self) -> None: + """Calls `cleanup` on the temporary directory, if it exists.""" + if self._tempdir: + self._tempdir.cleanup() + + def __del__(self) -> None: + # In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd. 
+ self._tempdir_cleanup() + + def stop(self, invoker: "Invoker") -> None: + self._tempdir_cleanup() diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py index 2bc7e16937e..125534c5002 100644 --- a/tests/test_object_serializer_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,12 +1,10 @@ +import tempfile from dataclasses import dataclass -from logging import Logger from pathlib import Path -from unittest.mock import Mock import pytest import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -31,11 +29,6 @@ def fwd_cache(tmp_path: Path): return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) -@pytest.fixture -def mock_invoker_with_logger(): - return Mock(Invoker, services=Mock(logger=Mock(Logger))) - - def test_obj_serializer_disk_initializes(tmp_path: Path): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path @@ -76,39 +69,33 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) +def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory) + assert obj_serializer._base_output_dir == tmp_path + assert obj_serializer._output_dir != tmp_path + assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name) - obj_2 = MockDataclass(foo="bar") - obj_2_name = obj_serializer.save(obj_2) - delete_all_result = obj_serializer._delete_all() +def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + del obj_serializer + assert not tempdir_path.exists() - assert not Path(obj_serializer._output_dir, obj_1_name).exists() - assert not Path(obj_serializer._output_dir, obj_2_name).exists() - assert delete_all_result.deleted_count == 2 +def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + obj_serializer.stop(None) # pyright: ignore [reportArgumentType] + assert not tempdir_path.exists() -def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) - assert obj_serializer._delete_on_startup is False +def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) - assert Path(tmp_path, obj_1_name).exists() - - -def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = 
ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) - assert obj_serializer._delete_on_startup is True - - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) + assert Path(obj_serializer._output_dir, obj_1_name).exists() assert not Path(tmp_path, obj_1_name).exists() From ba7b1b266507ce712bbf8f6209f1a50d888bf090 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 08:52:07 +1100 Subject: [PATCH 060/100] feat(nodes): extract LATENT_SCALE_FACTOR to constants.py --- invokeai/app/invocations/constants.py | 7 +++++++ invokeai/app/invocations/latent.py | 7 +------ invokeai/backend/tiles/tiles.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 invokeai/app/invocations/constants.py diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py new file mode 100644 index 00000000000..95b16f0d057 --- /dev/null +++ b/invokeai/app/invocations/constants.py @@ -0,0 +1,7 @@ +LATENT_SCALE_FACTOR = 8 +""" +HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +factor is hard-coded to a literal '8' rather than using this constant. +The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. +""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4137ab6e2f6..fedfc38402d 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,6 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -79,12 +80,6 @@ SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] -# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to -# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale -# factor is hard-coded to a literal '8' rather than using this constant. -# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. 
-LATENT_SCALE_FACTOR = 8 - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 3c400fc87ce..2757dadba20 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -3,7 +3,7 @@ import numpy as np -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend From 2005411f7e799e63a814995bea3cd32af5c77792 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 08:54:01 +1100 Subject: [PATCH 061/100] feat(nodes): use LATENT_SCALE_FACTOR in primitives.py, noise.py - LatentsOutput.build - NoiseOutput.build - Noise.width, Noise.height multiple_of --- invokeai/app/invocations/noise.py | 9 +++++---- invokeai/app/invocations/primitives.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 74b3d6e4cb1..335d3df292e 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,6 +4,7 @@ import torch from pydantic import field_validator +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -70,8 +71,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) @@ -93,13 +94,13 @@ class NoiseInvocation(BaseInvocation): ) width: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.width, ) height: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.height, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d0f95c92d02..43422134829 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -4,6 +4,7 @@ import torch +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( ColorField, ConditioningField, @@ -321,8 +322,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) From 083a4f3faa2e4c22a2d1949141ba2425c622c511 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:27:57 +1100 Subject: [PATCH 062/100] chore(backend): rename `ModelInfo` -> `LoadedModelInfo` We have two different classes named `ModelInfo` which might need to be used by API consumers. We need to export both but have to deal with this naming collision. 
The `ModelInfo` I've renamed here is the one that is returned when a model is loaded. It's the object least likely to be used by API consumers. --- invokeai/app/services/events/events_base.py | 10 +++++----- .../services/model_manager/model_manager_base.py | 4 ++-- .../model_manager/model_manager_default.py | 16 ++++++++-------- .../app/services/shared/invocation_context.py | 7 ++++--- invokeai/backend/__init__.py | 9 ++++++++- invokeai/backend/model_management/__init__.py | 2 +- .../backend/model_management/model_manager.py | 6 +++--- invokeai/backend/util/test_utils.py | 10 +++++----- invokeai/invocation_api/__init__.py | 4 ++-- 9 files changed, 38 insertions(+), 30 deletions(-) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index ad08ae03956..6b441efc2bf 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,7 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -201,7 +201,7 @@ def emit_model_load_completed( base_model: BaseModelType, model_type: ModelType, submodel: SubModelType, - model_info: ModelInfo, + loaded_model_info: LoadedModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -215,9 +215,9 @@ def emit_model_load_completed( "base_model": base_model, "model_type": model_type, "submodel": submodel, - "hash": model_info.hash, - "location": str(model_info.location), - "precision": str(model_info.precision), + "hash": loaded_model_info.hash, + "location": str(loaded_model_info.location), + "precision": str(loaded_model_info.precision), }, ) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index a9b53ae2242..f888c0ec973 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -14,8 +14,8 @@ from invokeai.backend.model_management import ( AddModelResult, BaseModelType, + LoadedModelInfo, MergeInterpolationMethod, - ModelInfo, ModelType, SchedulerPredictionType, SubModelType, @@ -48,7 +48,7 @@ def get_model( model_type: ModelType, submodel: Optional[SubModelType] = None, context_data: Optional[InvocationContextData] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Retrieve the indicated model with name and type. submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index b641dd3f1ed..c3712abf8e6 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -16,8 +16,8 @@ from invokeai.backend.model_management import ( AddModelResult, BaseModelType, + LoadedModelInfo, MergeInterpolationMethod, - ModelInfo, ModelManager, ModelMerger, ModelNotFoundException, @@ -98,7 +98,7 @@ def get_model( model_type: ModelType, submodel: Optional[SubModelType] = None, context_data: Optional[InvocationContextData] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """ Retrieve the indicated model. 
submodel can be used to get a part (such as the vae) of a diffusers mode. @@ -114,7 +114,7 @@ def get_model( submodel=submodel, ) - model_info = self.mgr.get_model( + loaded_model_info = self.mgr.get_model( model_name, base_model, model_type, @@ -128,10 +128,10 @@ def get_model( base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + loaded_model_info=loaded_model_info, ) - return model_info + return loaded_model_info def model_exists( self, @@ -273,7 +273,7 @@ def _emit_load_event( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - model_info: Optional[ModelInfo] = None, + loaded_model_info: Optional[LoadedModelInfo] = None, ): if self._invoker is None: return @@ -281,7 +281,7 @@ def _emit_load_event( if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() - if model_info: + if loaded_model_info: self._invoker.services.events.emit_model_load_completed( queue_id=context_data.queue_id, queue_item_id=context_data.queue_item_id, @@ -291,7 +291,7 @@ def _emit_load_event( base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + loaded_model_info=loaded_model_info, ) else: self._invoker.services.events.emit_model_load_started( diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 1ca44b78625..68fb78c1430 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -13,7 +13,7 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -272,14 +272,15 @@ def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelTy def load( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: + ) -> LoadedModelInfo: """ - Loads a model, returning its `ModelInfo` object. + Loads a model. :param model_name: The name of the model to get. :param base_model: The base model of the model to get. :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. + :returns: An object representing the loaded model. """ # The model manager emits events as it loads the model. 
It needs the context data to build diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index ae9a12edbe2..54a1843d463 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,5 +1,12 @@ """ Initialization file for invokeai.backend """ -from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401 +from .model_management import ( # noqa: F401 + BaseModelType, + LoadedModelInfo, + ModelCache, + ModelManager, + ModelType, + SubModelType, +) from .model_management.models import SilenceWarnings # noqa: F401 diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 03abf58eb46..d523a7a0c8d 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -3,7 +3,7 @@ Initialization file for invokeai.backend.model_management """ # This import must be first -from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType +from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType from .lora import ModelPatcher, ONNXModelPatcher from .model_cache import ModelCache diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 362d8d3ff55..da74ca3fb58 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -271,7 +271,7 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n @dataclass -class ModelInfo: +class LoadedModelInfo: context: ModelLocker name: str base_model: BaseModelType @@ -450,7 +450,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Given a model named identified in models.yaml, return an ModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -508,7 +508,7 @@ def get_model( model_hash = "" # TODO: - return ModelInfo( + return LoadedModelInfo( context=model_context, name=model_name, base_model=base_model, diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 09b9de9e984..685603cedc6 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -7,7 +7,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType @@ -34,8 +34,8 @@ def install_and_load_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> ModelInfo: - """Install a model if it is not already installed, then get the ModelInfo for that model. +) -> LoadedModelInfo: + """Install a model if it is not already installed, then get the LoadedModelInfo for that model. This is intended as a utility function for tests. @@ -49,9 +49,9 @@ def install_and_load_model( submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...). Returns: - ModelInfo + LoadedModelInfo """ - # If the requested model is already installed, return its ModelInfo. 
+ # If the requested model is already installed, return its LoadedModelInfo. with contextlib.suppress(ModelNotFoundException): return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index e80bc26a003..2d3ceca11e2 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -52,7 +52,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( @@ -121,7 +121,7 @@ # invokeai.app.services.config.config_default "InvokeAIAppConfig", # invokeai.backend.model_management.model_manager - "ModelInfo", + "LoadedModelInfo", # invokeai.backend.model_management.models.base "BaseModelType", "ModelType", From 8fb77e431e8c5c1e8863ee7d76388f061592eca4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:39:36 +1100 Subject: [PATCH 063/100] chore(nodes): export model-related objects from invocation_api --- invokeai/invocation_api/__init__.py | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 2d3ceca11e2..055dd12757d 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -28,6 +28,22 @@ WithMetadata, WithWorkflow, ) +from invokeai.app.invocations.model import ( + ClipField, + CLIPOutput, + LoraInfo, + LoraLoaderOutput, + LoRAModelField, + MainModelField, + ModelInfo, + ModelLoaderOutput, + SDXLLoraLoaderOutput, + UNetField, + UNetOutput, + VaeField, + VAEModelField, + VAEOutput, +) from invokeai.app.invocations.primitives import ( BooleanCollectionOutput, BooleanOutput, @@ -87,6 +103,21 @@ "UIType", "WithMetadata", "WithWorkflow", + # invokeai.app.invocations.model + "ModelInfo", + "LoraInfo", + "UNetField", + "ClipField", + "VaeField", + "MainModelField", + "LoRAModelField", + "VAEModelField", + "UNetOutput", + "VAEOutput", + "CLIPOutput", + "ModelLoaderOutput", + "LoraLoaderOutput", + "SDXLLoraLoaderOutput", # invokeai.app.invocations.primitives "BooleanCollectionOutput", "BooleanOutput", From 321b939d0efc66729b3e7ec8948e50798a83fa42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:43:36 +1100 Subject: [PATCH 064/100] chore(nodes): remove deprecation logic for nodes API --- .../app/services/shared/invocation_context.py | 112 ------------------ 1 file changed, 112 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 68fb78c1430..c68dc1140b2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from deprecated import deprecated from 
PIL.Image import Image from torch import Tensor @@ -334,30 +333,6 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m ) -deprecation_version = "3.7.0" -removed_version = "3.8.0" - - -def get_deprecation_reason(property_name: str, alternative: Optional[str] = None) -> str: - msg = f"{property_name} is deprecated as of v{deprecation_version}. It will be removed in v{removed_version}." - if alternative is not None: - msg += f" Use {alternative} instead." - msg += " See PLACEHOLDER_URL for details." - return msg - - -# Deprecation docstrings template. I don't think we can implement these programmatically with -# __doc__ because the IDE won't see them. - -""" -**DEPRECATED as of v3.7.0** - -PROPERTY_NAME will be removed in v3.8.0. Use ALTERNATIVE instead. See PLACEHOLDER_URL for details. - -OG_DOCSTRING -""" - - class InvocationContext: """ The `InvocationContext` provides access to various services and data for the current invocation. @@ -397,93 +372,6 @@ def __init__( self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" - @property - @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) - def services(self) -> InvocationServices: - """ - **DEPRECATED as of v3.7.0** - - `context.services` will be removed in v3.8.0. See PLACEHOLDER_URL for details. - - The invocation services. - """ - return self._services - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.graph_execution_state_id", "`context._data.session_id`"), - ) - def graph_execution_state_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.graph_execution_state_api` will be removed in v3.8.0. Use `context._data.session_id` instead. See PLACEHOLDER_URL for details. - - The ID of the session (aka graph execution state). - """ - return self._data.session_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_id`", "`context._data.queue_id`"), - ) - def queue_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_id` will be removed in v3.8.0. Use `context._data.queue_id` instead. See PLACEHOLDER_URL for details. - - The ID of the queue. - """ - return self._data.queue_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_item_id`", "`context._data.queue_item_id`"), - ) - def queue_item_id(self) -> int: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_item_id` will be removed in v3.8.0. Use `context._data.queue_item_id` instead. See PLACEHOLDER_URL for details. - - The ID of the queue item. - """ - return self._data.queue_item_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_batch_id`", "`context._data.batch_id`"), - ) - def queue_batch_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_batch_id` will be removed in v3.8.0. Use `context._data.batch_id` instead. See PLACEHOLDER_URL for details. - - The ID of the batch. - """ - return self._data.batch_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.workflow`", "`context._data.workflow`"), - ) - def workflow(self) -> Optional[WorkflowWithoutID]: - """ - **DEPRECATED as of v3.7.0** - - `context.workflow` will be removed in v3.8.0. Use `context._data.workflow` instead. See PLACEHOLDER_URL for details. 
- - The workflow associated with this queue item, if any. - """ - return self._data.workflow - def build_invocation_context( services: InvocationServices, From 1bbd13ead7dadef26db9b54febe7c19215723fcf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:51:25 +1100 Subject: [PATCH 065/100] chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES" This was named inaccurately. --- invokeai/app/invocations/constants.py | 7 +++++++ invokeai/app/invocations/latent.py | 10 ++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index 95b16f0d057..795e7a3b604 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -1,3 +1,7 @@ +from typing import Literal + +from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP + LATENT_SCALE_FACTOR = 8 """ HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to @@ -5,3 +9,6 @@ factor is hard-coded to a literal '8' rather than using this constant. The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. """ + +SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +"""A literal type representing the valid scheduler names.""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fedfc38402d..69e3f055ca8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,7 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -78,12 +78,10 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): - scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) @invocation( @@ -96,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput): class SchedulerInvocation(BaseInvocation): """Selects a scheduler.""" - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, @@ -234,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation): description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, From 6c4eeaa5692466df322d300bed8fa4d97c8f5e88 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 10:06:53 +1100 Subject: [PATCH 066/100] feat(nodes): add more missing exports to invocation_api Crawled through a few custom nodes to figure out what I had missed. 
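As a usage sketch, a custom node should now be able to import everything it needs from
the single `invokeai.invocation_api` module. The node below is hypothetical, for
illustration only:

    from invokeai.invocation_api import (
        SCHEDULER_NAME_VALUES,
        BaseInvocation,
        InputField,
        InvocationContext,
        SchedulerOutput,
        invocation,
    )

    @invocation("my_scheduler_passthrough", title="Scheduler Passthrough (Example)", version="1.0.0")
    class MySchedulerPassthroughInvocation(BaseInvocation):
        """Hypothetical custom node that passes a scheduler selection through unchanged."""

        scheduler: SCHEDULER_NAME_VALUES = InputField(default="euler", description="The scheduler to output")

        def invoke(self, context: InvocationContext) -> SchedulerOutput:
            return SchedulerOutput(scheduler=self.scheduler)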
--- invokeai/invocation_api/__init__.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 055dd12757d..e110b5a2db3 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -7,9 +7,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Classification, invocation, invocation_output, ) +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -28,6 +30,8 @@ WithMetadata, WithWorkflow, ) +from invokeai.app.invocations.latent import SchedulerOutput +from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput from invokeai.app.invocations.model import ( ClipField, CLIPOutput, @@ -68,6 +72,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -77,11 +82,14 @@ ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device +from invokeai.version import __version__ __all__ = [ # invokeai.app.invocations.baseinvocation "BaseInvocation", "BaseInvocationOutput", + "Classification", "invocation", "invocation_output", # invokeai.app.services.shared.invocation_context @@ -103,6 +111,12 @@ "UIType", "WithMetadata", "WithWorkflow", + # invokeai.app.invocations.latent + "SchedulerOutput", + # invokeai.app.invocations.metadata + "MetadataItemField", + "MetadataItemOutput", + "MetadataOutput", # invokeai.app.invocations.model "ModelInfo", "LoraInfo", @@ -157,4 +171,17 @@ "BaseModelType", "ModelType", "SubModelType", + # invokeai.app.invocations.constants + "SCHEDULER_NAME_VALUES", + # invokeai.version + "__version__", + # invokeai.backend.util.devices + "choose_precision", + "choose_torch_device", + "CPU_DEVICE", + "CUDA_DEVICE", + "MPS_DEVICE", + # invokeai.app.util.misc + "SEED_MAX", + "get_random_seed", ] From c76a6bd65f75a032a458220c55310b0592dfc90d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:02:30 +1100 Subject: [PATCH 067/100] chore(ui): regen types --- .../frontend/web/src/services/api/schema.ts | 1892 +---------------- 1 file changed, 67 insertions(+), 1825 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 45358ed97d5..1599b310c9a 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -680,70 +680,6 @@ export type components = { */ type: "add"; }; - /** - * Adjust Image Hue Plus - * @description Adjusts the Hue of an image by rotating it in the selected color space - */ - AdjustImageHuePlusInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * 
@description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to adjust */ - image?: components["schemas"]["ImageField"]; - /** - * Space - * @description Color space in which to rotate hue by polar coords (*: non-invertible) - * @default Okhsl - * @enum {string} - */ - space?: "HSV / HSL / RGB" | "Okhsl" | "Okhsv" | "*Oklch / Oklab" | "*LCh / CIELab" | "*UPLab (w/CIELab_to_UPLab.icc)"; - /** - * Degrees - * @description Degrees by which to rotate image hue - * @default 0 - */ - degrees?: number; - /** - * Preserve Lightness - * @description Whether to preserve CIELAB lightness values - * @default false - */ - preserve_lightness?: boolean; - /** - * Ok Adaptive Gamut - * @description Higher preserves chroma at the expense of lightness (Oklab) - * @default 0.05 - */ - ok_adaptive_gamut?: number; - /** - * Ok High Precision - * @description Use more steps in computing gamut (Oklab/Okhsv/Okhsl) - * @default true - */ - ok_high_precision?: boolean; - /** - * type - * @default img_hue_adjust_plus - * @constant - */ - type: "img_hue_adjust_plus"; - }; /** * AppConfig * @description App Config Response @@ -1454,39 +1390,6 @@ export type components = { */ type: "boolean_output"; }; - /** - * BRIA AI Background Removal - * @description Uses the new Bria 1.4 model to remove backgrounds from images. - */ - BriaRemoveBackgroundInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to crop */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default bria_bg_remove - * @constant - */ - type: "bria_bg_remove"; - }; /** * CLIPOutput * @description Base class for invocations that output a CLIP field @@ -1581,282 +1484,6 @@ export type components = { /** @description Base model (usually 'Any') */ base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; }; - /** - * CMYK Color Separation - * @description Get color images from a base color and two others that subtractively mix to obtain it - */ - CMYKColorSeparationInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * C Value - * @description Desired final cyan value - * @default 0 - */ - c_value?: number; - /** - * M Value - * @description Desired final magenta value - * @default 25 - */ - m_value?: number; - /** - * Y Value - * @description Desired final yellow value - * @default 28 - */ - y_value?: number; - /** - * K Value - * @description Desired final black value - * @default 76 - */ - k_value?: number; - /** - * C Split - * @description Desired cyan split point % [0..1.0] - * @default 0.5 - */ - c_split?: number; - /** - * M Split - * @description Desired magenta split point % [0..1.0] - * @default 1 - */ - m_split?: number; - /** - * Y Split - * @description Desired yellow split point % [0..1.0] - * @default 0 - */ - y_split?: number; - /** - * K Split - * @description Desired black split point % [0..1.0] - * @default 0.5 - */ - k_split?: number; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_separation - * @constant - */ - type: "cmyk_separation"; - }; - /** - * CMYK Merge - * @description Merge subtractive color channels (CMYK+alpha) - */ - CMYKMergeInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The c channel */ - c_channel?: components["schemas"]["ImageField"] | null; - /** @description The m channel */ - m_channel?: components["schemas"]["ImageField"] | null; - /** @description The y channel */ - y_channel?: components["schemas"]["ImageField"] | null; - /** @description The k channel */ - k_channel?: components["schemas"]["ImageField"] | null; - /** @description The alpha channel */ - alpha_channel?: components["schemas"]["ImageField"] | null; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_merge - * @constant - */ - type: "cmyk_merge"; - }; - /** - * CMYKSeparationOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSeparationOutput: { - /** @description Blank image of the specified color */ - color_image: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** @description Blank image of the first separated color */ - part_a: components["schemas"]["ImageField"]; - /** - * Rgb Red A - * @description R value of color part A - */ - rgb_red_a: number; - /** - * Rgb Green A - * @description G value of color part A - */ - rgb_green_a: number; - /** - * Rgb Blue A - * @description B value of color part A - */ - rgb_blue_a: number; - /** @description Blank image of the second separated color */ - part_b: components["schemas"]["ImageField"]; - /** - * Rgb Red B - * @description R value of color part B - */ - rgb_red_b: number; - /** - * Rgb Green B - * @description G value of color part B - */ - rgb_green_b: number; - /** - * Rgb Blue B - * @description B value of color part B - */ - rgb_blue_b: number; - /** - * type - * @default cmyk_separation_output - * @constant - */ - type: "cmyk_separation_output"; - }; - /** - * CMYK Split - * @description Split an image into subtractive color channels (CMYK+alpha) - */ - CMYKSplitInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to halftone */ - image?: components["schemas"]["ImageField"]; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_split - * @constant - */ - type: "cmyk_split"; - }; - /** - * CMYKSplitOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSplitOutput: { - /** @description Grayscale image of the cyan channel */ - c_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the magenta channel */ - m_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the yellow channel */ - y_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the k channel */ - k_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the alpha channel */ - alpha_channel: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** - * type - * @default cmyk_split_output - * @constant - */ - type: "cmyk_split_output"; - }; /** * CV2 Infill * @description Infills transparent areas of an image using OpenCV Inpainting @@ -3491,6 +3118,8 @@ export type components = { * @description Generates an openpose pose from an image using DWPose */ DWOpenposeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4014,11 +3643,20 @@ export type components = { */ priority: number; }; + /** ExposedField */ + ExposedField: { + /** Nodeid */ + nodeId: string; + /** Fieldname */ + fieldName: string; + }; /** - * Equivalent Achromatic Lightness - * @description Calculate Equivalent Achromatic Lightness from image + * FaceIdentifier + * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. */ - EquivalentAchromaticLightnessInvocation: { + FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4038,54 +3676,12 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description Image from which to get channel */ + /** @description Image to face detect */ image?: components["schemas"]["ImageField"]; /** - * type - * @default ealightness - * @constant - */ - type: "ealightness"; - }; - /** ExposedField */ - ExposedField: { - /** Nodeid */ - nodeId: string; - /** Fieldname */ - fieldName: string; - }; - /** - * FaceIdentifier - * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. 
- */ - FaceIdentifierInvocation: { - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to face detect */ - image?: components["schemas"]["ImageField"]; - /** - * Minimum Confidence - * @description Minimum confidence for face detection (lower if detection is failing) - * @default 0.5 + * Minimum Confidence + * @description Minimum confidence for face detection (lower if detection is failing) + * @default 0.5 */ minimum_confidence?: number; /** @@ -4301,39 +3897,6 @@ export type components = { */ y: number; }; - /** - * Flatten Histogram (Grayscale) - * @description Scales the values of an L-mode image by scaling them to the full range 0..255 in equal proportions - */ - FlattenHistogramMono: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Single-channel image for which to flatten the histogram */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default flatten_histogram_mono - * @constant - */ - type: "flatten_histogram_mono"; - }; /** * Float Collection Primitive * @description A collection of float primitive values @@ -4683,7 +4246,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ImageDilateOrErodeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["HandDepthMeshGraphormerProcessor"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["TextToMaskClipsegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["OffsetLatentsInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CMYKColorSeparationInvocation"] | components["schemas"]["NoiseImage2DInvocation"] | 
components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["EquivalentAchromaticLightnessInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageBlendInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageRotateInvocation"] | components["schemas"]["ShadowsHighlightsMidtonesMaskInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["BriaRemoveBackgroundInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageValueThresholdsInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["NoiseSpectralInvocation"] | components["schemas"]["TextMaskInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MaskedBlendLatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CMYKMergeInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["TextToMaskClipsegAdvancedInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageCompositorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["FlattenHistogramMono"] | components["schemas"]["AdjustImageHuePlusInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CMYKSplitInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | 
components["schemas"]["FloatInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageEnhanceInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["LatentConsistencyInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageOffsetInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"]; + [key: string]: components["schemas"]["ImageCropInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["AddInvocation"] | 
components["schemas"]["NoiseInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CvInpaintInvocation"] | 
components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IdealSizeInvocation"]; }; /** * Edges @@ -4720,7 +4283,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CMYKSplitOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CollectInvocationOutput"] | 
components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["CMYKSeparationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["HandDepthOutput"] | components["schemas"]["ShadowsHighlightsMidtonesMasksOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ColorCollectionOutput"]; + [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ControlOutput"] | 
components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ClipSkipInvocationOutput"]; }; /** * Errors @@ -4811,83 +4374,6 @@ export type components = { /** Detail */ detail?: components["schemas"]["ValidationError"][]; }; - /** - * Hand Depth w/ MeshGraphormer - * @description Generate hand depth maps to inpaint with using ControlNet - */ - HandDepthMeshGraphormerProcessor: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** - * Resolution - * @description Pixel resolution for output image - * @default 512 - */ - resolution?: number; - /** - * Mask Padding - * @description Amount to pad the hand mask by - * @default 30 - */ - mask_padding?: number; - /** - * Offload - * @description Offload model after usage - * @default false - */ - offload?: boolean; - /** - * type - * @default hand_depth_mesh_graphormer_image_processor - * @constant - */ - type: "hand_depth_mesh_graphormer_image_processor"; - }; - /** - * HandDepthOutput - * @description Base class for to output Meshgraphormer results - */ - HandDepthOutput: { - /** @description Improved hands depth map */ - image: components["schemas"]["ImageField"]; - /** @description Hands area mask */ - mask: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the depth map in pixels - */ - width: number; - /** - * Height - * @description The height of the depth map in pixels - */ - height: number; - /** - * type - * @default meshgraphormer_output - * @constant - */ - type: "meshgraphormer_output"; - }; /** * HED (softedge) Processor * @description Applies HED edge detection to image @@ -5260,87 +4746,6 @@ export type components = { */ type: "ideal_size_output"; }; - /** - * Image Layer Blend - * @description Blend two images together, with optional opacity, mask, and blend modes - */ - ImageBlendInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The top image to blend */ - layer_upper?: components["schemas"]["ImageField"]; - /** - * Blend Mode - * @description Available blend modes - * @default Normal - * @enum {string} - */ - blend_mode?: "Normal" | "Lighten Only" | "Darken Only" | "Lighten Only (EAL)" | "Darken Only (EAL)" | "Hue" | "Saturation" | "Color" | "Luminosity" | "Linear Dodge (Add)" | "Subtract" | "Multiply" | "Divide" | "Screen" | "Overlay" | "Linear Burn" | "Difference" | "Hard Light" | "Soft Light" | "Vivid Light" | "Linear Light" | "Color Burn" | "Color Dodge"; - /** - * Opacity - * @description Desired opacity of the upper layer - * @default 1 - */ - opacity?: number; - /** @description Optional mask, used to restrict areas from blending */ - mask?: components["schemas"]["ImageField"] | null; - /** - * Fit To Width - * @description Scale upper layer to fit base width - * @default false - */ - fit_to_width?: boolean; - /** - * Fit To Height - * @description Scale upper layer to fit base height - * @default true - */ - fit_to_height?: boolean; - /** @description The bottom image to blend */ - layer_base?: components["schemas"]["ImageField"]; - /** - * Color Space - * @description Available color spaces for blend computations - * @default Linear RGB - * @enum {string} - */ - color_space?: "RGB" | "Linear RGB" | "HSL (RGB)" | "HSV (RGB)" | "Okhsl" | "Okhsv" | "Oklch (Oklab)" | "LCh (CIELab)"; - /** - * Adaptive Gamut - * @description Adaptive gamut clipping (0=off). Higher prioritizes chroma over lightness - * @default 0 - */ - adaptive_gamut?: number; - /** - * High Precision - * @description Use more steps in computing gamut when possible - * @default true - */ - high_precision?: boolean; - /** - * type - * @default img_blend - * @constant - */ - type: "img_blend"; - }; /** * Blur Image * @description Blurs an image @@ -5594,77 +4999,6 @@ export type components = { */ type: "image_collection_output"; }; - /** - * Image Compositor - * @description Removes backdrop from subject image then overlays subject on background image - */ - ImageCompositorInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image of the subject on a plain monochrome background */ - image_subject?: components["schemas"]["ImageField"]; - /** @description Image of a background scene */ - image_background?: components["schemas"]["ImageField"]; - /** - * Chroma Key - * @description Can be empty for corner flood select, or CSS-3 color or tuple - * @default - */ - chroma_key?: string; - /** - * Threshold - * @description Subject isolation flood-fill threshold - * @default 50 - */ - threshold?: number; - /** - * Fill X - * @description Scale base subject image to fit background width - * @default false - */ - fill_x?: boolean; - /** - * Fill Y - * @description Scale base subject image to fit background height - * @default true - */ - fill_y?: boolean; - /** - * X Offset - * @description x-offset for the subject - * @default 0 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0 - */ - y_offset?: number; - /** - * type - * @default img_composite - * @constant - */ - type: "img_composite"; - }; /** * Convert Image Mode * @description Converts an image to a different mode. @@ -5847,10 +5181,23 @@ export type components = { board_id?: string | null; }; /** - * Image Dilate or Erode - * @description Dilate (expand) or erode (contract) an image + * ImageField + * @description An image primitive field + */ + ImageField: { + /** + * Image Name + * @description The name of the image + */ + image_name: string; + }; + /** + * Adjust Image Hue + * @description Adjusts the Hue of an image. */ - ImageDilateOrErodeInvocation: { + ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5870,158 +5217,24 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description The image from which to create a mask */ + /** @description The image to adjust */ image?: components["schemas"]["ImageField"]; /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Radius W - * @description Width (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_w?: number; - /** - * Radius H - * @description Height (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_h?: number; - /** - * Mode - * @description How to operate on the image - * @default Dilate - * @enum {string} + * Hue + * @description The degrees by which to rotate the hue, 0-360 + * @default 0 */ - mode?: "Dilate" | "Erode"; + hue?: number; /** * type - * @default img_dilate_erode + * @default img_hue_adjust * @constant */ - type: "img_dilate_erode"; + type: "img_hue_adjust"; }; /** - * Enhance Image - * @description Applies processing from PIL's ImageEnhance module. - */ - ImageEnhanceInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image for which to apply processing */ - image?: components["schemas"]["ImageField"]; - /** - * Invert - * @description Whether to invert the image colors - * @default false - */ - invert?: boolean; - /** - * Color - * @description Color enhancement factor - * @default 1 - */ - color?: number; - /** - * Contrast - * @description Contrast enhancement factor - * @default 1 - */ - contrast?: number; - /** - * Brightness - * @description Brightness enhancement factor - * @default 1 - */ - brightness?: number; - /** - * Sharpness - * @description Sharpness enhancement factor - * @default 1 - */ - sharpness?: number; - /** - * type - * @default img_enhance - * @constant - */ - type: "img_enhance"; - }; - /** - * ImageField - * @description An image primitive field - */ - ImageField: { - /** - * Image Name - * @description The name of the image - */ - image_name: string; - }; - /** - * Adjust Image Hue - * @description Adjusts the Hue of an image. - */ - ImageHueAdjustmentInvocation: { - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to adjust */ - image?: components["schemas"]["ImageField"]; - /** - * Hue - * @description The degrees by which to rotate the hue, 0-360 - * @default 0 - */ - hue?: number; - /** - * type - * @default img_hue_adjust - * @constant - */ - type: "img_hue_adjust"; - }; - /** - * Inverse Lerp Image - * @description Inverse linear interpolation of all pixels of an image + * Inverse Lerp Image + * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { /** @description The board to save the image to */ @@ -6216,57 +5429,6 @@ export type components = { */ type: "img_nsfw"; }; - /** - * Offset Image - * @description Offsets an image by a given percentage (or pixel amount). - */ - ImageOffsetInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * As Pixels - * @description Interpret offsets as pixels rather than percentages - * @default false - */ - as_pixels?: boolean; - /** @description Image to be offset */ - image?: components["schemas"]["ImageField"]; - /** - * X Offset - * @description x-offset for the subject - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_image - * @constant - */ - type: "offset_image"; - }; /** * ImageOutput * @description Base class for nodes that output a single image @@ -6432,63 +5594,6 @@ export type components = { */ type: "img_resize"; }; - /** - * Rotate/Flip Image - * @description Rotates an image by a given angle (in degrees clockwise). - */ - ImageRotateInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to be rotated clockwise */ - image?: components["schemas"]["ImageField"]; - /** - * Degrees - * @description Angle (in degrees clockwise) by which to rotate - * @default 90 - */ - degrees?: number; - /** - * Expand To Fit - * @description If true, extends the image boundary to fit the rotated content - * @default true - */ - expand_to_fit?: boolean; - /** - * Flip Horizontal - * @description If true, flips the image horizontally - * @default false - */ - flip_horizontal?: boolean; - /** - * Flip Vertical - * @description If true, flips the image vertically - * @default false - */ - flip_vertical?: boolean; - /** - * type - * @default rotate_image - * @constant - */ - type: "rotate_image"; - }; /** * Scale Image * @description Scales an image by a factor @@ -6603,69 +5708,6 @@ export type components = { */ thumbnail_url: string; }; - /** - * Image Value Thresholds - * @description Clip image to pure black/white past specified thresholds - */ - ImageValueThresholdsInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Make light areas dark and vice versa - * @default false - */ - invert_output?: boolean; - /** - * Renormalize Values - * @description Rescale remaining values from minimum to maximum - * @default false - */ - renormalize_values?: boolean; - /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Threshold Upper - * @description Threshold above which will be set to full value - * @default 0.5 - */ - threshold_upper?: number; - /** - * Threshold Lower - * @description Threshold below which will be set to minimum value - * @default 0.5 - */ - threshold_lower?: number; - /** - * type - * @default img_val_thresholds - * @constant - */ - type: "img_val_thresholds"; - }; /** * Add Invisible Watermark * @description Add an invisible watermark to an image @@ -8025,47 +7067,6 @@ export type components = { */ type: "tomask"; }; - /** - * Blend Latents/Noise (Masked) - * @description Blend two latents using a given alpha and mask. Latents must have same size. - */ - MaskedBlendLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents_a?: components["schemas"]["LatentsField"]; - /** @description Latents tensor */ - latents_b?: components["schemas"]["LatentsField"]; - /** @description Mask for blending in latents B */ - mask?: components["schemas"]["ImageField"]; - /** - * Alpha - * @description Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B. - * @default 0.5 - */ - alpha?: number; - /** - * type - * @default lmblend - * @constant - */ - type: "lmblend"; - }; /** * Mediapipe Face Processor * @description Applies mediapipe face processing to image @@ -8679,88 +7680,8 @@ export type components = { value: string | number; }; /** - * 2D Noise Image - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseImage2DInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noiseimg_2d - * @constant - */ - type: "noiseimg_2d"; - }; - /** - * Noise - * @description Generates latent noise. + * Noise + * @description Generates latent noise. */ NoiseInvocation: { /** @@ -8835,84 +7756,6 @@ export type components = { */ type: "noise_output"; }; - /** - * Noise (Spectral characteristics) - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseSpectralInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noise_spectral - * @constant - */ - type: "noise_spectral"; - }; /** * Normal BAE Processor * @description Applies NormalBae processing to image @@ -8960,107 +7803,6 @@ export type components = { */ type: "normalbae_image_processor"; }; - /** - * ONNX Latents to Image - * @description Generates an image from latents. 
- */ - ONNXLatentsToImageInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Denoised latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** @description VAE */ - vae?: components["schemas"]["VaeField"]; - /** - * type - * @default l2i_onnx - * @constant - */ - type: "l2i_onnx"; - }; - /** - * ONNXModelLoaderOutput - * @description Model loader output - */ - ONNXModelLoaderOutput: { - /** - * UNet - * @description UNet (scheduler, LoRAs) - */ - unet?: components["schemas"]["UNetField"]; - /** - * CLIP - * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count - */ - clip?: components["schemas"]["ClipField"]; - /** - * VAE Decoder - * @description VAE - */ - vae_decoder?: components["schemas"]["VaeField"]; - /** - * VAE Encoder - * @description VAE - */ - vae_encoder?: components["schemas"]["VaeField"]; - /** - * type - * @default model_loader_output_onnx - * @constant - */ - type: "model_loader_output_onnx"; - }; - /** ONNX Prompt (Raw) */ - ONNXPromptInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description Raw prompt text (no parsing) - * @default - */ - prompt?: string; - /** @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ - clip?: components["schemas"]["ClipField"]; - /** - * type - * @default prompt_onnx - * @constant - */ - type: "prompt_onnx"; - }; /** * ONNXSD1Config * @description Model config for ONNX format models based on sd-1. @@ -9242,117 +7984,6 @@ export type components = { /** Upcast Attention */ upcast_attention: boolean; }; - /** - * ONNX Text to Latents - * @description Generates latents from conditionings. - */ - ONNXTextToLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Positive conditioning tensor */ - positive_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Negative conditioning tensor */ - negative_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Noise tensor */ - noise?: components["schemas"]["LatentsField"]; - /** - * Steps - * @description Number of steps to run - * @default 10 - */ - steps?: number; - /** - * Cfg Scale - * @description Classifier-Free Guidance scale - * @default 7.5 - */ - cfg_scale?: number | number[]; - /** - * Scheduler - * @description Scheduler to use during inference - * @default euler - * @enum {string} - */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm"; - /** - * Precision - * @description Precision to use - * @default tensor(float16) - * @enum {string} - */ - precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)"; - /** @description UNet (scheduler, LoRAs) */ - unet?: components["schemas"]["UNetField"]; - /** - * Control - * @description ControlNet(s) to apply - */ - control?: components["schemas"]["ControlField"] | components["schemas"]["ControlField"][]; - /** - * type - * @default t2l_onnx - * @constant - */ - type: "t2l_onnx"; - }; - /** - * Offset Latents - * @description Offsets a latents tensor by a given percentage of height/width. - */ - OffsetLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** - * X Offset - * @description Approx percentage to offset (H) - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description Approx percentage to offset (V) - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_latents - * @constant - */ - type: "offset_latents"; - }; /** OffsetPaginatedResults[BoardDTO] */ OffsetPaginatedResults_BoardDTO_: { /** @@ -9392,58 +8023,12 @@ export type components = { * Total * @description Total number of items in result */ - total: number; - /** - * Items - * @description Items - */ - items: components["schemas"]["ImageDTO"][]; - }; - /** - * OnnxModelField - * @description Onnx model field - */ - OnnxModelField: { - /** - * Model Name - * @description Name of the model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - }; - /** - * ONNX Main Model - * @description Loads a main model, outputting its submodels. - */ - OnnxModelLoaderInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description ONNX Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["OnnxModelField"]; + total: number; /** - * type - * @default onnx_model_loader - * @constant + * Items + * @description Items */ - type: "onnx_model_loader"; + items: components["schemas"]["ImageDTO"][]; }; /** PaginatedResults[ModelSummary] */ PaginatedResults_ModelSummary_: { @@ -10852,106 +9437,6 @@ export type components = { */ total: number; }; - /** - * Shadows/Highlights/Midtones - * @description Extract a Shadows/Highlights/Midtones mask from an image - */ - ShadowsHighlightsMidtonesMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image from which to extract mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Highlight Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.75 - */ - highlight_threshold?: number; - /** - * Upper Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.7 - */ - upper_mid_threshold?: number; - /** - * Lower Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.3 - */ - lower_mid_threshold?: number; - /** - * Shadow Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.25 - */ - shadow_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels to grow (or shrink) the mask areas - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Gaussian blur radius to apply to the masks - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default shmmask - * @constant - */ - type: "shmmask"; - }; - /** ShadowsHighlightsMidtonesMasksOutput */ - ShadowsHighlightsMidtonesMasksOutput: { - /** @description Soft-edged highlights mask */ - highlights_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged midtones mask */ - midtones_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged shadows mask */ - shadows_mask?: components["schemas"]["ImageField"]; - /** - * Width - * @description Width of the input/outputs - */ - width: number; - /** - * Height - * @description Height of the input/outputs - */ - height: number; - /** - * type - * @default shmmask_output - * @constant - */ - type: "shmmask_output"; - }; /** * Show Image * @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline. @@ -11833,249 +10318,6 @@ export type components = { /** Right */ right: number; }; - /** - * Text Mask - * @description Creates a 2D rendering of a text mask from a given font - */ - TextMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Width - * @description The width of the desired mask - * @default 512 - */ - width?: number; - /** - * Height - * @description The height of the desired mask - * @default 512 - */ - height?: number; - /** - * Text - * @description The text to render - * @default - */ - text?: string; - /** - * Font - * @description Path to a FreeType-supported TTF/OTF font file - * @default - */ - font?: string; - /** - * Size - * @description Desired point size of text to use - * @default 64 - */ - size?: number; - /** - * Angle - * @description Angle of rotation to apply to the text - * @default 0 - */ - angle?: number; - /** - * X Offset - * @description x-offset for text rendering - * @default 24 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for text rendering - * @default 36 - */ - y_offset?: number; - /** - * Invert - * @description Whether to invert color of the output - * @default false - */ - invert?: boolean; - /** - * type - * @default text_mask - * @constant - */ - type: "text_mask"; - }; - /** - * Text to Mask Advanced (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegAdvancedInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt 1 - * @description First prompt with which to create a mask - */ - prompt_1?: string; - /** - * Prompt 2 - * @description Second prompt with which to create a mask (optional) - */ - prompt_2?: string; - /** - * Prompt 3 - * @description Third prompt with which to create a mask (optional) - */ - prompt_3?: string; - /** - * Prompt 4 - * @description Fourth prompt with which to create a mask (optional) - */ - prompt_4?: string; - /** - * Combine - * @description How to combine the results - * @default or - * @enum {string} - */ - combine?: "or" | "and" | "none (rgba multiplex)"; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 1 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0 - */ - background_threshold?: number; - /** - * type - * @default txt2mask_clipseg_adv - * @constant - */ - type: "txt2mask_clipseg_adv"; - }; - /** - * Text to Mask (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - 
/** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt - * @description The prompt with which to create a mask - */ - prompt?: string; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 0.4 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0.4 - */ - background_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels by which to grow (or shrink) mask after thresholding - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Radius of blur to apply after thresholding - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default txt2mask_clipseg - * @constant - */ - type: "txt2mask_clipseg"; - }; /** * TextualInversionConfig * @description Model config for textual inversion embeddings. @@ -13067,53 +11309,53 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * T2IAdapterModelFormat + * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * T2IAdapterModelFormat * @description An enumeration. 
     * @enum {string}
     */
-  IPAdapterModelFormat: "invokeai";
+  T2IAdapterModelFormat: "diffusers";
    /**
-   * StableDiffusionXLModelFormat
+   * CLIPVisionModelFormat
     * @description An enumeration.
     * @enum {string}
     */
-  StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
+  CLIPVisionModelFormat: "diffusers";
    /**
-   * StableDiffusionOnnxModelFormat
+   * ControlNetModelFormat
     * @description An enumeration.
     * @enum {string}
     */
-  StableDiffusionOnnxModelFormat: "olive" | "onnx";
+  ControlNetModelFormat: "checkpoint" | "diffusers";
    /**
-   * CLIPVisionModelFormat
+   * StableDiffusion1ModelFormat
     * @description An enumeration.
     * @enum {string}
     */
-  CLIPVisionModelFormat: "diffusers";
+  StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
    /**
-   * StableDiffusion2ModelFormat
+   * IPAdapterModelFormat
     * @description An enumeration.
     * @enum {string}
     */
-  StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
+  IPAdapterModelFormat: "invokeai";
    /**
-   * StableDiffusion1ModelFormat
+   * StableDiffusion2ModelFormat
     * @description An enumeration.
     * @enum {string}
     */
-  StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
+  StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
    /**
-   * ControlNetModelFormat
+   * StableDiffusionOnnxModelFormat
     * @description An enumeration.
     * @enum {string}
     */
-  ControlNetModelFormat: "checkpoint" | "diffusers";
+  StableDiffusionOnnxModelFormat: "olive" | "onnx";
  };
  responses: never;
  parameters: never;

From f5447cdc23f86a9642f47e1e98f5ac35056c1943 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 13 Feb 2024 16:30:00 +1100
Subject: [PATCH 068/100] feat(ui): workflow schema v3 (WIP)

The changes aim to deduplicate data between workflows and node templates,
decoupling workflows from internal implementation details. A good amount of
data that was needlessly duplicated from the node template into the workflow
is removed.

These changes substantially reduce the file size of workflows (and therefore
of images with embedded workflows):

- Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to
  10.9kb (407 lines).
- Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines)
  to 51.9kb (1774 lines).

The trade-off is that we need to reference node templates to get things like
the field type. In practice, this is a non-issue, because we need a node
template to do anything with a node anyway.

- Field types are not included in the workflow. They are always pulled from
  the node templates. With the v3 schema, the structure of a field type is
  properly an internal implementation detail that we are free to change as
  needed; previously, such a change required a migration of the workflow
  itself.
- Workflow nodes no longer have an `outputs` property, and there is no
  longer such a thing as a `FieldOutputInstance`. These exist only on the
  templates. They were never referenced at a time when we didn't also have
  the templates available, and there'd be no reason to do so.
- Node width and height are no longer stored in the node. They weren't used.
  Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be
  programmatically changing these properties. A future enhancement can
  properly add node resizing.
- The `nodeTemplates` slice is merged back into `nodesSlice` as
  `nodes.templates`. It turns out keeping them in separate slices was just
  a hassle.
- Workflow migration logic is updated to support the new schema. V1
  workflows now migrate all the way to v3.
- Changes throughout the nodes code accommodate the above.
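To make the shape change concrete, here is a minimal sketch of the same node
as it might be serialized under v2 and v3. The node and field shapes are
illustrative only (abridged, with hypothetical values), not the exact
schemas:

```ts
// v2 (old): template-derived data is duplicated into the workflow.
// Abridged, hypothetical values -- not the exact schema.
const v2Node = {
  id: 'a1b2c3d4',
  type: 'compel',
  inputs: {
    prompt: {
      id: 'prompt',
      name: 'prompt',
      label: '',
      value: 'a cute cat',
      // Field type copied from the template; changing the field type
      // structure meant migrating every stored workflow.
      type: { name: 'StringField', isCollectionOrScalar: false },
    },
  },
  // FieldOutputInstances, also copied from the template, and never read
  // without the template in hand anyway.
  outputs: {
    conditioning: { id: 'conditioning', name: 'conditioning', type: { name: 'ConditioningField' } },
  },
  width: 320, // unused
  height: 219, // unused
};

// v3 (new): only user-owned state is stored. Types, outputs, titles, etc.
// are resolved at runtime from the template in `nodes.templates[type]`.
const v3Node = {
  id: 'a1b2c3d4',
  type: 'compel',
  inputs: {
    prompt: { name: 'prompt', label: '', value: 'a cute cat' },
  },
};
```

Anything not stored on the v3 node is looked up through the new selectors
(e.g. `selectNodeTemplate`, `selectFieldInputTemplate`) against
`nodes.templates`, as the hook changes below show.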
---
 .../middleware/devtools/actionSanitizer.ts | 2 +-
 .../listeners/getOpenAPISchema.ts | 2 +-
 .../listeners/updateAllNodesRequested.ts | 3 +-
 .../listeners/workflowLoadRequested.ts | 2 +-
 invokeai/frontend/web/src/app/store/store.ts | 2 -
 .../frontend/web/src/app/store/storeHooks.ts | 3 +-
 invokeai/frontend/web/src/app/store/util.ts | 2 +
 .../src/common/hooks/useIsReadyToEnqueue.ts | 6 +-
 .../flow/AddNodePopover/AddNodePopover.tsx | 14 +-
 .../flow/edges/util/makeEdgeSelector.ts | 18 +-
 .../InvocationNodeCollapsedHandles.tsx | 19 +-
 .../Invocation/InvocationNodeWrapper.tsx | 4 +-
 .../Invocation/fields/EditableFieldTitle.tsx | 4 +-
 .../nodes/Invocation/fields/FieldTitle.tsx | 2 +-
 .../Invocation/fields/FieldTooltipContent.tsx | 6 +-
 .../nodes/Invocation/fields/InputField.tsx | 6 +-
 .../Invocation/fields/InputFieldRenderer.tsx | 29 +-
 .../Invocation/fields/LinearViewField.tsx | 4 +-
 .../nodes/Invocation/fields/OutputField.tsx | 10 +-
 .../inspector/InspectorDetailsTab.tsx | 5 +-
 .../inspector/InspectorOutputsTab.tsx | 5 +-
 .../inspector/InspectorTemplateTab.tsx | 5 +-
 .../hooks/useAnyOrDirectInputFieldNames.ts | 20 +-
 .../src/features/nodes/hooks/useBuildNode.ts | 2 +-
 .../hooks/useConnectionInputFieldNames.ts | 20 +-
 .../nodes/hooks/useConnectionState.ts | 10 +-
 .../nodes/hooks/useDoNodeVersionsMatch.ts | 18 +-
 .../nodes/hooks/useDoesInputHaveValue.ts | 12 +-
 .../src/features/nodes/hooks/useFieldData.ts | 23 -
 .../nodes/hooks/useFieldInputInstance.ts | 15 +-
 .../features/nodes/hooks/useFieldInputKind.ts | 15 +-
 .../nodes/hooks/useFieldInputTemplate.ts | 15 +-
 .../src/features/nodes/hooks/useFieldLabel.ts | 10 +-
 .../nodes/hooks/useFieldOutputInstance.ts | 23 -
 .../nodes/hooks/useFieldOutputTemplate.ts | 15 +-
 .../features/nodes/hooks/useFieldTemplate.ts | 21 +-
 .../nodes/hooks/useFieldTemplateTitle.ts | 16 +-
 .../features/nodes/hooks/useFieldType.ts.ts | 14 +-
 .../nodes/hooks/useGetNodesNeedUpdate.ts | 5 +-
 .../features/nodes/hooks/useHasImageOutput.ts | 13 +-
 .../features/nodes/hooks/useIsIntermediate.ts | 10 +-
 .../nodes/hooks/useIsValidConnection.ts | 38 +-
 .../nodes/hooks/useNodeClassification.ts | 17 +-
 .../src/features/nodes/hooks/useNodeData.ts | 7 +-
 .../src/features/nodes/hooks/useNodeLabel.ts | 9 +-
 .../nodes/hooks/useNodeNeedsUpdate.ts | 15 +-
 .../src/features/nodes/hooks/useNodePack.ts | 10 +-
 .../features/nodes/hooks/useNodeTemplate.ts | 13 +-
 .../nodes/hooks/useNodeTemplateByType.ts | 10 +-
 .../nodes/hooks/useNodeTemplateTitle.ts | 15 +-
 .../nodes/hooks/useOutputFieldNames.ts | 20 +-
 .../src/features/nodes/hooks/useUseCache.ts | 8 +-
 .../nodes/hooks/useWorkflowWatcher.ts | 4 +-
 .../web/src/features/nodes/store/actions.ts | 4 +-
 .../nodes/store/nodeTemplatesSlice.ts | 24 -
 .../src/features/nodes/store/nodesSlice.ts | 15 +-
 .../web/src/features/nodes/store/selectors.ts | 51 +
 .../web/src/features/nodes/store/types.ts | 5 +-
 .../store/util/findConnectionToValidHandle.ts | 30 +-
 .../util/makeIsConnectionValidSelector.ts | 2 +-
 .../src/features/nodes/store/workflowSlice.ts | 6 +-
 .../web/src/features/nodes/types/field.ts | 130 +--
 .../src/features/nodes/types/invocation.ts | 22 +-
 .../web/src/features/nodes/types/v2/common.ts | 188 ++++
 .../src/features/nodes/types/v2/constants.ts | 80 ++
 .../web/src/features/nodes/types/v2/error.ts | 58 ++
 .../web/src/features/nodes/types/v2/field.ts | 875 ++++++++++++++++++
.../src/features/nodes/types/v2/invocation.ts | 93 ++ .../src/features/nodes/types/v2/metadata.ts | 77 ++ .../src/features/nodes/types/v2/openapi.ts | 86 ++ .../web/src/features/nodes/types/v2/semver.ts | 21 + .../src/features/nodes/types/v2/workflow.ts | 89 ++ .../web/src/features/nodes/types/workflow.ts | 10 +- .../nodes/util/node/buildInvocationNode.ts | 22 +- .../features/nodes/util/node/nodeUpdate.ts | 1 - .../util/schema/buildFieldInputInstance.ts | 3 - .../nodes/util/workflow/buildWorkflow.ts | 20 +- .../nodes/util/workflow/migrations.ts | 32 +- .../nodes/util/workflow/validateWorkflow.ts | 4 +- .../workflowLibrary/hooks/useSaveWorkflow.ts | 4 +- 80 files changed, 1936 insertions(+), 612 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/util.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/selectors.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/common.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/constants.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/error.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/field.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/semver.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index 2e2d2014b23..ed8c82d91ca 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,6 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { cloneDeep } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import type { Graph } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts index b2d36159098..88518e2c0bb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; diff --git 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 752c3b09df2..ac1298da5ba 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => { actionCreator: updateAllNodesRequested, effect: (action, { dispatch, getState }) => { const log = logger('nodes'); - const nodes = getState().nodes.nodes; - const templates = getState().nodeTemplates.templates; + const { nodes, templates } = getState().nodes; let unableToUpdateCount = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 9307031e6d0..ad41dc2654f 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => { effect: (action, { dispatch, getState }) => { const log = logger('nodes'); const { workflow, asCopy } = action.payload; - const nodeTemplates = getState().nodeTemplates.templates; + const nodeTemplates = getState().nodes.templates; try { const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index e25e1351eb9..270662c3d21 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice'; import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice'; -import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice'; import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice'; import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice'; @@ -46,7 +45,6 @@ const allReducers = { [gallerySlice.name]: gallerySlice.reducer, [generationSlice.name]: generationSlice.reducer, [nodesSlice.name]: nodesSlice.reducer, - [nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer, [postprocessingSlice.name]: postprocessingSlice.reducer, [systemSlice.name]: systemSlice.reducer, [configSlice.name]: configSlice.reducer, diff --git a/invokeai/frontend/web/src/app/store/storeHooks.ts b/invokeai/frontend/web/src/app/store/storeHooks.ts index f1a9aa979c0..6bc904acb31 100644 --- a/invokeai/frontend/web/src/app/store/storeHooks.ts +++ b/invokeai/frontend/web/src/app/store/storeHooks.ts @@ -1,7 +1,8 @@ import type { AppThunkDispatch, RootState } from 'app/store/store'; import type { TypedUseSelectorHook } from 'react-redux'; -import { useDispatch, useSelector } from 'react-redux'; +import { useDispatch, 
useSelector, useStore } from 'react-redux'; // Use throughout your app instead of plain `useDispatch` and `useSelector` export const useAppDispatch = () => useDispatch(); export const useAppSelector: TypedUseSelectorHook = useSelector; +export const useAppStore = () => useStore(); diff --git a/invokeai/frontend/web/src/app/store/util.ts b/invokeai/frontend/web/src/app/store/util.ts new file mode 100644 index 00000000000..381f7f85d26 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/util.ts @@ -0,0 +1,2 @@ +export const EMPTY_ARRAY = []; +export const EMPTY_OBJECT = {}; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 4952fa1c47b..baa704e75ca 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice'; @@ -23,11 +22,10 @@ const selector = createMemoizedSelector( selectGenerationSlice, selectSystemSlice, selectNodesSlice, - selectNodeTemplatesSlice, selectDynamicPromptsSlice, activeTabNameSelector, ], - (controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => { + (controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => { const { initialImage, model, positivePrompt } = generation; const { isConnected } = system; @@ -54,7 +52,7 @@ const selector = createMemoizedSelector( return; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; + const nodeTemplate = nodes.templates[node.data.type]; if (!nodeTemplate) { // Node type not found diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index b24b52c6abf..061209cafc0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { + addNodePopoverClosed, + addNodePopoverOpened, + nodeAdded, + selectNodesSlice, +} from 'features/nodes/store/nodesSlice'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; @@ -54,10 +58,10 @@ const 
AddNodePopover = () => { const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType); const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType); - const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => { + const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { // If we have a connection in progress, we need to filter the node choices const filteredNodeTemplates = fieldFilter - ? filter(nodeTemplates.templates, (template) => { + ? filter(nodes.templates, (template) => { const handles = handleFilter === 'source' ? template.inputs : template.outputs; return some(handles, (handle) => { @@ -67,7 +71,7 @@ const AddNodePopover = () => { return validateSourceAndTargetTypes(sourceType, targetType); }); }) - : map(nodeTemplates.templates); + : map(nodes.templates); const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 4bfc588e675..ba40b4984cd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,10 +1,17 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; +const defaultReturnValue = { + isSelected: false, + shouldAnimate: false, + stroke: colorTokenToCssVar('base.500'), +}; + export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, @@ -12,14 +19,19 @@ export const makeEdgeSelector = ( targetHandleId: string | null | undefined, selected?: boolean ) => - createMemoizedSelector(selectNodesSlice, (nodes) => { + createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = sourceNode?.selected || targetNode?.selected || selected; - const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined; + const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + if (!sourceNode || !sourceHandleId) { + return defaultReturnValue; + } + + const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); + const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; const stroke = sourceType && nodes.shouldColorEdges ? 
getFieldColor(sourceType) : colorTokenToCssVar('base.500'); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index c287842f6ed..b888e8a5162 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,6 +1,5 @@ import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/invocation'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import { map } from 'lodash-es'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; @@ -13,7 +12,7 @@ interface Props { const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' }; const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { - const data = useNodeData(nodeId); + const template = useNodeTemplate(nodeId); const { base600 } = useChakraThemeTokens(); const dummyHandleStyles: CSSProperties = useMemo( @@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { [dummyHandleStyles] ); - if (!isInvocationNodeData(data)) { + if (!template) { return null; } @@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { <> - {map(data.inputs, (input) => ( + {map(template.inputs, (input) => ( { ))} - {map(data.outputs, (output) => ( + {map(template.outputs, (output) => ( ) => { const { id: nodeId, type, isOpen, label } = data; const hasTemplateSelector = useMemo( - () => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])), + () => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx index c2231f703ab..e02b1a1474e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx @@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent'; interface Props { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; isMissingInput?: boolean; withTooltip?: boolean; } @@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => { return ( : undefined} + label={withTooltip ? 
: undefined} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} > { - const field = useFieldInstance(nodeId, fieldName); + const field = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); const isInputTemplate = isFieldInputTemplate(fieldTemplate); const fieldTypeName = useFieldTypeName(fieldTemplate?.type); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 2b9f7960e4b..66b0d3f7556 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { const [isHovered, setIsHovered] = useState(false); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'input' }); + useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { if (!fieldTemplate) { @@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { @@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index c1d52c1d4fb..b6e331c1149 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,6 +1,5 @@ -import { Box, Text } from '@invoke-ai/ui-library'; -import { useFieldInstance } from 'features/nodes/hooks/useFieldData'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; +import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate, @@ -38,7 +37,6 @@ import { isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; @@ -63,17 +61,8 @@ type InputFieldProps = { }; const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { - const { t } = useTranslation(); - const fieldInstance = useFieldInstance(nodeId, fieldName); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); - - if (fieldTemplate?.fieldKind === 'output') { - return ( - - {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} - - ); - } + const fieldInstance = useFieldInputInstance(nodeId, fieldName); + const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; @@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } - if (fieldInstance && fieldTemplate) { + if (fieldTemplate) { // Fallback for when there is no component for the type return null; } - - return ( - - - {t('nodes.unknownFieldType', { 
type: fieldInstance?.type.name })} - - - ); }; export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index d0a30ecc3c7..0cd199f7a47 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { /> - + {isValueChanged && ( { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 48c4c0d7404..f2d776a2da1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,6 +1,5 @@ import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance'; import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import type { PropsWithChildren } from 'react'; @@ -18,18 +17,17 @@ interface Props { const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const fieldInstance = useFieldOutputInstance(nodeId, fieldName); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'output' }); + useConnectionState({ nodeId, fieldName, kind: 'outputs' }); - if (!fieldTemplate || !fieldInstance) { + if (!fieldTemplate) { return ( {t('nodes.unknownOutput', { - name: fieldTemplate?.title ?? 
fieldName, + name: fieldName, })} @@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { return ( } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" shouldWrapChildren diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index b7c9033d6b2..d72d2f5aa8d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea'; import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { return; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index ee7dfaa6932..978eeddd24a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? 
nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index 28f0e82d68c..ea6e8ed704d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; return { template: lastSelectedNodeTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index d0263a8bdaf..c882924e241 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,26 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useAnyOrDirectInputFieldNames = (nodeId: string) => { +export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if 
(!nodeTemplate) { - return []; - } - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index aecc9318938..b19edf3c85a 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial = { }; export const useBuildNode = () => { - const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates); + const nodeTemplates = useAppSelector((s) => s.nodes.templates); const flow = useReactFlow(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 23f318517b5..dc8a05b88c2 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -1,28 +1,24 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useConnectionInputFieldNames = (nodeId: string) => { +export const useConnectionInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } // get the visible fields - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (field.input === 'connection' && !field.type.isCollectionOrScalar) || !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index a6f8b663f69..97b96f323ad 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector( export type UseConnectionStateProps = { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; }; export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { @@ -26,8 +26,8 @@ 
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.edges.filter((edge) => { return ( - (kind === 'input' ? edge.target : edge.source) === nodeId && - (kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName + (kind === 'inputs' ? edge.target : edge.source) === nodeId && + (kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName ); }).length ) @@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType), + () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), [nodeId, fieldName, kind, fieldType] ); @@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.connectionStartParams?.nodeId === nodeId && nodes.connectionStartParams?.handleId === fieldName && - nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind] + nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind] ) ), [fieldName, kind, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index bfbf0a3b2d3..91994cf7525 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { compareVersions } from 'compare-versions'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoNodeVersionsMatch = (nodeId: string) => { +export const useDoNodeVersionsMatch = (nodeId: string): boolean => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { + createSelector(selectNodesSlice, (nodes) => { + const data = selectNodeData(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!template?.version || !data?.version) { return false; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - if (!nodeTemplate?.version || !node.data?.version) { - return false; - } - return compareVersions(nodeTemplate.version, node.data.version) === 0; + return compareVersions(template.version, data.version) === 0; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index cfe5c90d9cc..5051eaa55b3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -1,18 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { +export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + const data = selectNodeData(nodes, nodeId); + if (!data) { + return false; } - return node?.data.inputs[fieldName]?.value !== undefined; + return data.inputs[fieldName]?.value !== undefined; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts deleted file mode 100644 index 8b35a2d44be..00000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldData = useAppSelector(selector); - - return fieldData; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts index 0793f1f9529..25065e7aba5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts @@ -1,23 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputInstance = (nodeId: string, fieldName: string) => { +export const useFieldInputInstance = (nodeId: string, fieldName: string): 
FieldInputInstance | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.inputs[fieldName]; + return selectFieldInputInstance(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); - const fieldTemplate = useAppSelector(selector); + const fieldData = useAppSelector(selector); - return fieldTemplate; + return fieldData; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index 11d44dbde2e..08de3d9b205 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -1,21 +1,16 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInput } from 'features/nodes/types/field'; import { useMemo } from 'react'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - const fieldTemplate = nodeTemplate?.inputs[fieldName]; - return fieldTemplate?.input; + createSelector(selectNodesSlice, (nodes): FieldInput | null => { + const template = selectFieldInputTemplate(nodes, nodeId, fieldName); + return template?.input ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts index 8533d2be8df..e8289d7e07d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - return nodeTemplate?.inputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldInputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts index ef57956047e..92eab8d1b15 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldLabel = (nodeId: string, fieldName: string) => { +export const useFieldLabel = (nodeId: string, fieldName: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]?.label; + return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts deleted file mode 100644 index 8b71f1ea014..00000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldOutputInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.outputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldTemplate = useAppSelector(selector); - - return fieldTemplate; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts index 11f592b399e..cb154071e97 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => { const 
selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.outputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index 663821da81e..7be4ecfd4df 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -1,21 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplate = ( + nodeId: string, + fieldName: string, + kind: 'inputs' | 'outputs' +): FieldInputTemplate | FieldOutputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createMemoizedSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName); } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]; + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index cfdcda6efab..e41e0195724 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -1,21 +1,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts index a834726a136..a71a4d044ee 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts @@ -1,20 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldType } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null; } - const field = node.data[KIND_MAP[kind]][fieldName]; - return field?.type; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? 
null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index a8019c92d6d..71344197d54 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -1,13 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; -const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => +const selector = createSelector(selectNodesSlice, (nodes) => nodes.nodes.filter(isInvocationNode).some((node) => { - const template = nodeTemplates.templates[node.data.type]; + const template = nodes.templates[node.data.type]; if (!template) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts index 617e713c7cc..3ac3cabb220 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts @@ -1,24 +1,21 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { some } from 'lodash-es'; import { useMemo } from 'react'; -export const useHasImageOutput = (nodeId: string) => { +export const useHasImageOutput = (nodeId: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } + const template = selectNodeTemplate(nodes, nodeId); return some( - node.data.outputs, + template?.outputs, (output) => output.type.name === 'ImageField' && // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes - node.data.type !== 'image' + template?.type !== 'image' ); }), [nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts index 729bfa0cea0..3fad0a2a861 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useIsIntermediate = (nodeId: string) => { +export const useIsIntermediate = (nodeId: string): boolean => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id 
=== nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.isIntermediate; + return selectNodeData(nodes, nodeId)?.isIntermediate ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 39a8abbe7a2..ded05c7b9bf 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,11 +1,10 @@ // TODO: enable this at some point -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; -import { useReactFlow } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -13,36 +12,31 @@ import { useReactFlow } from 'reactflow'; */ export const useIsValidConnection = () => { - const flow = useReactFlow(); + const store = useAppStore(); const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph); const isValidConnection = useCallback( ({ source, sourceHandle, target, targetHandle }: Connection): boolean => { - const edges = flow.getEdges(); - const nodes = flow.getNodes(); // Connection must have valid targets if (!(source && sourceHandle && target && targetHandle)) { return false; } - // Find the source and target nodes - const sourceNode = flow.getNode(source) as Node<InvocationNodeData>; - const targetNode = flow.getNode(target) as Node<InvocationNodeData>; - - // Conditional guards against undefined nodes/handles - if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) { + if (source === target) { + // Don't allow nodes to connect to themselves, even if validation is disabled return false; } - const sourceField = sourceNode.data.outputs[sourceHandle]; - const targetField = targetNode.data.inputs[targetHandle]; + const state = store.getState(); + const { nodes, edges, templates } = state.nodes; - if (!sourceField || !targetField) { - // something has gone terribly awry - return false; - } + // Find the source and target nodes + const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>; + const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>; + const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; + const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; - if (source === target) { - // Don't allow nodes to connect to themselves, even if validation is disabled + // Conditional guards against undefined nodes/handles + if (!(sourceFieldTemplate && targetFieldTemplate)) { return false; } @@ -69,20 +63,20 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetField.type.name !== 'CollectionItemField' + targetFieldTemplate.type.name !== 'CollectionItemField' ) { return false; } // Must use the originalType here if it exists - if 
(!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { return false; } // Graphs must be acyclic (no loops!) return getIsGraphAcyclic(source, target, nodes, edges); }, - [flow, shouldValidateGraph] + [shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts index c61721030eb..bab8ff3f194 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts @@ -1,20 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { Classification } from 'features/nodes/types/common'; import { useMemo } from 'react'; -export const useNodeClassification = (nodeId: string) => { +export const useNodeClassification = (nodeId: string): Classification | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.classification; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.classification ?? 
null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts index c507def5ee3..fa21008ff8b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts @@ -1,14 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectNodeData } from 'features/nodes/store/selectors'; +import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeData = (nodeId: string) => { +export const useNodeData = (nodeId: string): InvocationNodeData | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - return node?.data; + return selectNodeData(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index c5fc43742a1..31dcb9c466e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -1,19 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - - return node.data.label; + return selectNodeData(nodes, nodeId)?.label ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts index e6efa667f12..aa0294f70f0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -1,21 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { useMemo } from 'react'; export const useNodeNeedsUpdate = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const template = nodeTemplates.templates[node?.data.type ?? 
'']; - if (isInvocationNode(node) && template) { - return getNeedsUpdate(node, template); + createMemoizedSelector(selectNodesSlice, (nodes) => { + const node = selectInvocationNode(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!node || !template) { + return false; } - return false; + return getNeedsUpdate(node, template); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts index ca3dd5cfdf6..5c920866e9d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodePack = (nodeId: string) => { +export const useNodePack = (nodeId: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.nodePack; + return selectNodeData(nodes, nodeId)?.nodePack ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts index 7544cbff461..866c9275fb3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts @@ -1,16 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplate = (nodeId: string) => { +export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - return nodeTemplate; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index 8fd1345f6f5..a0c870f6941 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -1,14 +1,14 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplateByType = (type: string) => { +export const useNodeTemplateByType = (type: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => { - return nodeTemplates.templates[type]; + createSelector(selectNodesSlice, (nodes) => { + return nodes.templates[type] ?? null; }), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 15d2ec38c32..120b8c758be 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,21 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodeTemplateTitle = (nodeId: string) => { +export const useNodeTemplateTitle = (nodeId: string): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined; - - return nodeTemplate?.title; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.title ?? 
null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e352bd8b90f..24863080a74 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -1,8 +1,8 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { map } from 'lodash-es'; import { useMemo } from 'react'; @@ -10,17 +10,13 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - return getSortedFilteredFieldNames(map(nodeTemplate.outputs)); + return getSortedFilteredFieldNames(map(template.outputs)); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index edfc990882b..aaca80039b0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useUseCache = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.useCache; + return selectNodeData(nodes, nodeId)?.useCache ?? 
false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts index 0e4806d81b7..5d79c154428 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts @@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow'; import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import { useEffect } from 'react'; -export const $builtWorkflow = atom<WorkflowV2 | null>(null); +export const $builtWorkflow = atom<WorkflowV3 | null>(null); const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => { $builtWorkflow.set(buildWorkflowFast(arg)); diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 00457494bfb..b32a3ba9979 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,5 +1,5 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { Graph } from 'services/api/types'; export const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt'); @@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{ export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested'); -export const workflowLoaded = createAction<WorkflowV2>('workflow/workflowLoaded'); +export const workflowLoaded = createAction<WorkflowV3>('workflow/workflowLoaded'); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts deleted file mode 100644 index c211131aab7..00000000000 --- a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import type { RootState } from 'app/store/store'; -import type { InvocationTemplate } from 'features/nodes/types/invocation'; - -import type { NodeTemplatesState } from './types'; - -export const initialNodeTemplatesState: NodeTemplatesState = { - templates: {}, -}; - -export const nodesTemplatesSlice = createSlice({ - name: 'nodeTemplates', - initialState: initialNodeTemplatesState, - reducers: { - nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => { - state.templates = action.payload; - }, - }, -}); - -export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions; - -export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index aee01b381ba..6b596da0633 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ 
b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -42,7 +42,7 @@ import { zT2IAdapterModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; -import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation'; +import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; import { cloneDeep, forEach } from 'lodash-es'; import type { @@ -92,6 +92,7 @@ export const initialNodesState: NodesState = { _version: 1, nodes: [], edges: [], + templates: {}, connectionStartParams: null, connectionStartFieldType: null, connectionMade: false, @@ -190,6 +191,7 @@ export const nodesSlice = createSlice({ node, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -224,12 +226,12 @@ export const nodesSlice = createSlice({ if (!nodeId || !handleId) { return; } - const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - const node = state.nodes?.[nodeIndex]; + const node = state.nodes.find((n) => n.id === nodeId); if (!isInvocationNode(node)) { return; } - const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; + const template = state.templates[node.data.type]; + const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId]; state.connectionStartFieldType = field?.type ?? null; }, connectionMade: (state, action: PayloadAction<Connection>) => { @@ -260,6 +262,7 @@ export const nodesSlice = createSlice({ mouseOverNode, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -677,6 +680,9 @@ export const nodesSlice = createSlice({ selectionModeChanged: (state, action: PayloadAction<boolean>) => { state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial; }, + nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => { + state.templates = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(workflowLoaded, (state, action) => { @@ -808,6 +814,7 @@ export const { shouldValidateGraphChanged, viewportChanged, edgeAdded, + nodeTemplatesBuilt, } = nodesSlice.actions; // This is used for tracking `state.workflow.isTouched` diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts new file mode 100644 index 00000000000..90675d62707 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -0,0 +1,51 @@ +import type { NodesState } from 'features/nodes/store/types'; +import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; + +export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return null; + } + return node; +}; + +export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => { + return selectInvocationNode(nodesSlice, nodeId)?.data ?? 
null; +}; + +export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => { + const node = selectInvocationNode(nodesSlice, nodeId); + if (!node) { + return null; + } + return nodesSlice.templates[node.data.type] ?? null; +}; + +export const selectFieldInputInstance = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputInstance | null => { + const data = selectNodeData(nodesSlice, nodeId); + return data?.inputs[fieldName] ?? null; +}; + +export const selectFieldInputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.inputs[fieldName] ?? null; +}; + +export const selectFieldOutputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldOutputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.outputs[fieldName] ?? null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 8b0de447e43..1a040d2c705 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -5,13 +5,14 @@ import type { InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow'; export type NodesState = { _version: 1; nodes: AnyNode[]; edges: InvocationNodeEdge[]; + templates: Record<string, InvocationTemplate>; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; connectionMade: boolean; @@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & { value: StatefulFieldValue; }; -export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & { +export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & { _version: 1; isTouched: boolean; mode: WorkflowMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 9f2c37a2ad7..ef899c5f414 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,4 +1,6 @@ -import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field'; +import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import type { Connection, Edge, HandleType, Node } from 'reactflow'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; @@ -9,7 +11,7 @@ const isValidConnection = ( handleCurrentType: HandleType, handleCurrentFieldType: FieldType, node: Node, - handle: FieldInputInstance | FieldOutputInstance + handle: FieldInputTemplate | FieldOutputTemplate ) => { let isValidConnection = true; if (handleCurrentType === 'source') { @@ -38,24 +40,31 @@ const isValidConnection = ( }; export const findConnectionToValidHandle = ( - node: Node, - nodes: Node[], - edges: Edge[], + node: AnyNode, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + 
templates: Record<string, InvocationTemplate>, handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, handleCurrentFieldType: FieldType ): Connection | null => { - if (node.id === handleCurrentNodeId) { + if (node.id === handleCurrentNodeId || !isInvocationNode(node)) { return null; } - const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs; + const template = templates[node.data.type]; + + if (!template) { + return null; + } + + const handles = handleCurrentType === 'source' ? template.inputs : template.outputs; // Prioritize handles whose name matches the node we're coming from - if (handles[handleCurrentName]) { - const handle = handles[handleCurrentName]; + const handle = handles[handleCurrentName]; + if (handle) { const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name; @@ -77,6 +86,9 @@ export const findConnectionToValidHandle = ( for (const handleName in handles) { const handle = handles[handleName]; + if (!handle) { + continue; + } const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 8575932cbdd..d6ea0d9c86e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | null ) => { return createSelector(selectNodesSlice, (nodesSlice) => { if (!fieldType) { diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 2978f25138d..4f40a68e1f0 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -10,10 +10,10 @@ import type { } from 'features/nodes/store/types'; import type { FieldIdentifier } from 'features/nodes/types/field'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow'; import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es'; -export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = { +export const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = { name: '', author: '', description: '', @@ -22,7 +22,7 @@ export const blankWorkflow: Omit<WorkflowV2, 'nodes' | 'edges'> = { tags: '', notes: '', exposedFields: [], - meta: { version: '2.0.0', category: 'user' }, + meta: { version: '3.0.0', category: 'user' }, id: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 38f1af55dd8..aa6164d6e53 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -46,20 +46,11 @@ export type FieldInput = z.infer<typeof zFieldInput>; export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); export type FieldUIComponent = z.infer<typeof zFieldUIComponent>; -export 
const zFieldInstanceBase = z.object({ - id: z.string().trim().min(1), +export const zFieldInputInstanceBase = z.object({ name: z.string().trim().min(1), -}); -export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('input'), label: z.string().nullish(), }); -export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('output'), -}); -export type FieldInstanceBase = z.infer; export type FieldInputInstanceBase = z.infer; -export type FieldOutputInstanceBase = z.infer; export const zFieldTemplateBase = z.object({ name: z.string().min(1), @@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({ }); export const zIntegerFieldValue = z.number().int(); export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIntegerFieldType, value: zIntegerFieldValue, }); -export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIntegerFieldType, -}); export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, default: zIntegerFieldValue, @@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({ }); export const zFloatFieldValue = z.number(); export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zFloatFieldType, value: zFloatFieldValue, }); -export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zFloatFieldType, -}); export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, default: zFloatFieldValue, @@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type FloatFieldType = z.infer; export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; -export type FloatFieldOutputInstance = z.infer; export type FloatFieldInputTemplate = z.infer; export type FloatFieldOutputTemplate = z.infer; export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => @@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({ }); export const zStringFieldValue = z.string(); export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStringFieldType, value: zStringFieldValue, }); -export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStringFieldType, -}); export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, default: zStringFieldValue, @@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StringFieldType = z.infer; export type StringFieldValue = z.infer; export type StringFieldInputInstance = z.infer; -export type StringFieldOutputInstance = z.infer; export type StringFieldInputTemplate = z.infer; export type StringFieldOutputTemplate = z.infer; export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => @@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({ }); export const zBooleanFieldValue = z.boolean(); export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBooleanFieldType, value: zBooleanFieldValue, }); -export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBooleanFieldType, -}); export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, default: zBooleanFieldValue, @@ -222,7 +195,6 @@ export const 
zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BooleanFieldType = z.infer; export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; -export type BooleanFieldOutputInstance = z.infer; export type BooleanFieldInputTemplate = z.infer; export type BooleanFieldOutputTemplate = z.infer; export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => @@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({ }); export const zEnumFieldValue = z.string(); export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zEnumFieldType, value: zEnumFieldValue, }); -export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zEnumFieldType, -}); export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, default: zEnumFieldValue, @@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type EnumFieldType = z.infer; export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; -export type EnumFieldOutputInstance = z.infer; export type EnumFieldInputTemplate = z.infer; export type EnumFieldOutputTemplate = z.infer; export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => @@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({ }); export const zImageFieldValue = zImageField.optional(); export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zImageFieldType, value: zImageFieldValue, }); -export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zImageFieldType, -}); export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, default: zImageFieldValue, @@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ImageFieldType = z.infer; export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; -export type ImageFieldOutputInstance = z.infer; export type ImageFieldInputTemplate = z.infer; export type ImageFieldOutputTemplate = z.infer; export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => @@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({ }); export const zBoardFieldValue = zBoardField.optional(); export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBoardFieldType, value: zBoardFieldValue, }); -export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBoardFieldType, -}); export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, default: zBoardFieldValue, @@ -317,7 +275,6 @@ export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BoardFieldType = z.infer; export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; -export type BoardFieldOutputInstance = z.infer; export type BoardFieldInputTemplate = z.infer; export type BoardFieldOutputTemplate = z.infer; export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => @@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({ }); export const zColorFieldValue = zColorField.optional(); export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zColorFieldType, value: zColorFieldValue, }); -export const zColorFieldOutputInstance = 
zFieldOutputInstanceBase.extend({ - type: zColorFieldType, -}); export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, default: zColorFieldValue, @@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ColorFieldType = z.infer; export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; -export type ColorFieldOutputInstance = z.infer; export type ColorFieldInputTemplate = z.infer; export type ColorFieldOutputTemplate = z.infer; export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => @@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({ }); export const zMainModelFieldValue = zMainModelField.optional(); export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zMainModelFieldType, value: zMainModelFieldValue, }); -export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zMainModelFieldType, -}); export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, default: zMainModelFieldValue, @@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type MainModelFieldType = z.infer; export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; -export type MainModelFieldOutputInstance = z.infer; export type MainModelFieldInputTemplate = z.infer; export type MainModelFieldOutputTemplate = z.infer; export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => @@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLMainModelFieldType, value: zSDXLMainModelFieldValue, }); -export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLMainModelFieldType, -}); export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, default: zSDXLMainModelFieldValue, @@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend export type SDXLMainModelFieldType = z.infer; export type SDXLMainModelFieldValue = z.infer; export type SDXLMainModelFieldInputInstance = z.infer; -export type SDXLMainModelFieldOutputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; export type SDXLMainModelFieldOutputTemplate = z.infer; export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => @@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. 
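A minimal usage sketch of the slimmed-down field schemas above, assuming only the zod schemas defined in this file; the field name 'seed', the value 42, and the `nodesSlice`/`nodeId` bindings are invented for illustration. With `id`, `fieldKind`, and the per-instance `type` removed, an input instance is just a name, an optional label, and a value, and type metadata is instead looked up on the node's template via the new selectors:

  // Passes: zFieldInputInstanceBase requires a non-empty name; label is nullish and may be omitted.
  const seed = zIntegerFieldInputInstance.parse({ name: 'seed', value: 42 });
  // Throws: zIntegerFieldValue is z.number().int(), so a non-integer value is rejected.
  // zIntegerFieldInputInstance.parse({ name: 'seed', value: 4.2 });
  // The field's type now comes from the template, not the instance:
  // const seedType = selectFieldInputTemplate(nodesSlice, nodeId, 'seed')?.type;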
export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, value: zSDXLRefinerModelFieldValue, }); -export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, -}); export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, default: zSDXLRefinerModelFieldValue, @@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext export type SDXLRefinerModelFieldType = z.infer; export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; -export type SDXLRefinerModelFieldOutputInstance = z.infer; export type SDXLRefinerModelFieldInputTemplate = z.infer; export type SDXLRefinerModelFieldOutputTemplate = z.infer; export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => @@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({ }); export const zVAEModelFieldValue = zVAEModelField.optional(); export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zVAEModelFieldType, value: zVAEModelFieldValue, }); -export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zVAEModelFieldType, -}); export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, default: zVAEModelFieldValue, @@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type VAEModelFieldType = z.infer; export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; -export type VAEModelFieldOutputInstance = z.infer; export type VAEModelFieldInputTemplate = z.infer; export type VAEModelFieldOutputTemplate = z.infer; export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => @@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({ }); export const zLoRAModelFieldValue = zLoRAModelField.optional(); export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zLoRAModelFieldType, value: zLoRAModelFieldValue, }); -export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zLoRAModelFieldType, -}); export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, default: zLoRAModelFieldValue, @@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type LoRAModelFieldType = z.infer; export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; -export type LoRAModelFieldOutputInstance = z.infer; export type LoRAModelFieldInputTemplate = z.infer; export type LoRAModelFieldOutputTemplate = z.infer; export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => @@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({ }); export const zControlNetModelFieldValue = zControlNetModelField.optional(); export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zControlNetModelFieldType, value: zControlNetModelFieldValue, }); -export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zControlNetModelFieldType, -}); export const zControlNetModelFieldInputTemplate = 
zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, default: zControlNetModelFieldValue, @@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type ControlNetModelFieldType = z.infer; export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; -export type ControlNetModelFieldOutputInstance = z.infer; export type ControlNetModelFieldInputTemplate = z.infer; export type ControlNetModelFieldOutputTemplate = z.infer; export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => @@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIPAdapterModelFieldType, value: zIPAdapterModelFieldValue, }); -export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIPAdapterModelFieldType, -}); export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue, @@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten export type IPAdapterModelFieldType = z.infer; export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; -export type IPAdapterModelFieldOutputInstance = z.infer; export type IPAdapterModelFieldInputTemplate = z.infer; export type IPAdapterModelFieldOutputTemplate = z.infer; export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => @@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, value: zT2IAdapterModelFieldValue, }); -export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, -}); export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, default: zT2IAdapterModelFieldValue, @@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type T2IAdapterModelFieldType = z.infer; export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; -export type T2IAdapterModelFieldOutputInstance = z.infer; export type T2IAdapterModelFieldInputTemplate = z.infer; export type T2IAdapterModelFieldOutputTemplate = z.infer; export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => @@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({ }); export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSchedulerFieldType, value: zSchedulerFieldValue, }); -export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSchedulerFieldType, -}); export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, default: zSchedulerFieldValue, @@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type 
SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
 export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
 export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
-export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
 export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
 export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
 export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
@@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({
 });
 export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS finagling
 export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
-  type: zStatelessFieldType,
   value: zStatelessFieldValue,
 });
-export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
-  type: zStatelessFieldType,
-});
 export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
   type: zStatelessFieldType,
   default: zStatelessFieldValue,
@@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
 export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
 export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
 export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
-export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
 export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
 export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
 // #endregion
@@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
   zFieldInputInstance.safeParse(val).success;
 // #endregion
 
-// #region StatefulFieldOutputInstance & FieldOutputInstance
-export const zStatefulFieldOutputInstance = z.union([
-  zIntegerFieldOutputInstance,
-  zFloatFieldOutputInstance,
-  zStringFieldOutputInstance,
-  zBooleanFieldOutputInstance,
-  zEnumFieldOutputInstance,
-  zImageFieldOutputInstance,
-  zBoardFieldOutputInstance,
-  zMainModelFieldOutputInstance,
-  zSDXLMainModelFieldOutputInstance,
-  zSDXLRefinerModelFieldOutputInstance,
-  zVAEModelFieldOutputInstance,
-  zLoRAModelFieldOutputInstance,
-  zControlNetModelFieldOutputInstance,
-  zIPAdapterModelFieldOutputInstance,
-  zT2IAdapterModelFieldOutputInstance,
-  zColorFieldOutputInstance,
-  zSchedulerFieldOutputInstance,
-]);
-export type StatefulFieldOutputInstance = z.infer<typeof zStatefulFieldOutputInstance>;
-export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance =>
-  zStatefulFieldOutputInstance.safeParse(val).success;
-
-export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]);
-export type FieldOutputInstance = z.infer<typeof zFieldOutputInstance>;
-export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance =>
-  zFieldOutputInstance.safeParse(val).success;
-// #endregion
-
 // #region StatefulFieldInputTemplate & FieldInputTemplate
 export const zStatefulFieldInputTemplate = z.union([
   zIntegerFieldInputTemplate,
diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts
index 86ec70fd9bd..5ccb19430da 100644
--- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts
@@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow';
 import { z } from 'zod';
 
 import { zClassification, zProgressImage } from './common';
-import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field';
+import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate }
from './field'; import { zSemVer } from './semver'; // #region InvocationTemplate @@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer; // #region NodeData export const zInvocationNodeData = z.object({ id: z.string().trim().min(1), - type: z.string().trim().min(1), + version: zSemVer, + nodePack: z.string().min(1).nullish(), label: z.string(), - isOpen: z.boolean(), notes: z.string(), + type: z.string().trim().min(1), + inputs: z.record(zFieldInputInstance), + isOpen: z.boolean(), isIntermediate: z.boolean(), useCache: z.boolean(), - version: zSemVer, - nodePack: z.string().min(1).nullish(), - inputs: z.record(zFieldInputInstance), - outputs: z.record(zFieldOutputInstance), }); export const zNotesNodeData = z.object({ @@ -62,11 +61,12 @@ export type NotesNode = Node; export type CurrentImageNode = Node; export type AnyNode = Node; -export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => +export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); -export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => +export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData => Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts new file mode 100644 index 00000000000..b5244743799 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -0,0 +1,188 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zClassification = z.enum(['stable', 'beta', 'prototype']); +export type Classification = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelName = z.string().min(3); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, 
+}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int().gt(0), + height: z.number().int().gt(0), + type: z.literal('image_output'), +}); +export 
type ImageOutput = z.infer<typeof zImageOutput>;
+export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
+// #endregion
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts
new file mode 100644
index 00000000000..35ef9e9fd2c
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts
@@ -0,0 +1,79 @@
+import type { Node } from 'reactflow';
+
+/**
+ * How long to wait, in milliseconds, before showing a tooltip when hovering a field handle.
+ */
+export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
+
+/**
+ * The width of a node in the UI in pixels.
+ */
+export const NODE_WIDTH = 320;
+
+/**
+ * This class name is special - reactflow uses it to identify the drag handle of a node,
+ * applying the appropriate listeners to it.
+ */
+export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
+
+/**
+ * reactflow-specific properties shared between all node types.
+ */
+export const SHARED_NODE_PROPERTIES: Partial<Node> = {
+  dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
+};
+
+/**
+ * Helper for getting the kind of a field.
+ */
+export const KIND_MAP = {
+  input: 'inputs' as const,
+  output: 'outputs' as const,
+};
+
+/**
+ * Model types' handles are rendered as squares in the UI.
+ */
+export const MODEL_TYPES = [
+  'IPAdapterModelField',
+  'ControlNetModelField',
+  'LoRAModelField',
+  'MainModelField',
+  'SDXLMainModelField',
+  'SDXLRefinerModelField',
+  'VaeModelField',
+  'UNetField',
+  'VaeField',
+  'ClipField',
+  'T2IAdapterModelField',
+];
+
+/**
+ * Colors for each field type - applies to their handles and edges.
+ */
+export const FIELD_COLORS: { [key: string]: string } = {
+  BoardField: 'purple.500',
+  BooleanField: 'green.500',
+  ClipField: 'green.500',
+  ColorField: 'pink.300',
+  ConditioningField: 'cyan.500',
+  ControlField: 'teal.500',
+  ControlNetModelField: 'teal.500',
+  EnumField: 'blue.500',
+  FloatField: 'orange.500',
+  ImageField: 'purple.500',
+  IntegerField: 'red.500',
+  IPAdapterField: 'teal.500',
+  IPAdapterModelField: 'teal.500',
+  LatentsField: 'pink.500',
+  LoRAModelField: 'teal.500',
+  MainModelField: 'teal.500',
+  SDXLMainModelField: 'teal.500',
+  SDXLRefinerModelField: 'teal.500',
+  StringField: 'yellow.500',
+  T2IAdapterField: 'teal.500',
+  T2IAdapterModelField: 'teal.500',
+  UNetField: 'red.500',
+  VaeField: 'blue.500',
+  VaeModelField: 'teal.500',
+};
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/error.ts b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts
new file mode 100644
index 00000000000..905b487fb04
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts
@@ -0,0 +1,58 @@
+/**
+ * Invalid Workflow Version Error
+ * Raised when a workflow version is not recognized.
+ */
+export class WorkflowVersionError extends Error {
+  /**
+   * Create WorkflowVersionError
+   * @param {String} message
+   */
+  constructor(message: string) {
+    super(message);
+    this.name = this.constructor.name;
+  }
+}
+/**
+ * Workflow Migration Error
+ * Raised when a workflow migration fails.
+ */
+export class WorkflowMigrationError extends Error {
+  /**
+   * Create WorkflowMigrationError
+   * @param {String} message
+   */
+  constructor(message: string) {
+    super(message);
+    this.name = this.constructor.name;
+  }
+}
+
+/**
+ * Unable to Update Node Error
+ * Raised when a node cannot be updated.
+ */
+export class NodeUpdateError extends Error {
+  /**
+   * Create NodeUpdateError
+   * @param {String} message
+   */
+  constructor(message: string) {
+    super(message);
+    this.name = this.constructor.name;
+  }
+}
+
+/**
+ * FieldParseError
+ * Raised when a field cannot be parsed from a field schema.
+ */
+export class FieldParseError extends Error {
+  /**
+   * Create FieldParseError
+   * @param {String} message
+   */
+  constructor(message: string) {
+    super(message);
+    this.name = this.constructor.name;
+  }
+}
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts
new file mode 100644
index 00000000000..38f1af55dd8
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts
@@ -0,0 +1,875 @@
+import { z } from 'zod';
+
+import {
+  zBoardField,
+  zColorField,
+  zControlNetModelField,
+  zImageField,
+  zIPAdapterModelField,
+  zLoRAModelField,
+  zMainModelField,
+  zSchedulerField,
+  zT2IAdapterModelField,
+  zVAEModelField,
+} from './common';
+
+/**
+ * zod schemas & inferred types for fields.
+ *
+ * These schemas and types are only required for stateful fields - fields that have UI components
+ * and allow the user to directly provide values.
+ *
+ * This includes primitive values (numbers, strings, booleans), models, scheduler, etc.
+ *
+ * If a field type does not have a UI component, then it does not need to be included here, because
+ * we never store its value. Such field types will be handled via the "StatelessField" logic.
+ *
+ * Fields require:
+ * - zFieldType - zod schema for the field type
+ * - zFieldValue - zod schema for the field value
+ * - zFieldInputInstance - zod schema for the field's input instance
+ * - zFieldOutputInstance - zod schema for the field's output instance
+ * - zFieldInputTemplate - zod schema for the field's input template
+ * - zFieldOutputTemplate - zod schema for the field's output template
+ * - inferred types for each schema
+ * - type guards for InputInstance and InputTemplate
+ *
+ * These then must be added to the unions at the bottom of this file (a sketch of this recipe for a hypothetical field follows below).
+ */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isCollectionOrScalar: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer; +export type IntegerFieldInputTemplate = z.infer; +export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + 
type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer; +export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer; +export type StringFieldOutputInstance = z.infer; +export type StringFieldInputTemplate = z.infer; +export type StringFieldOutputTemplate = z.infer; +export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer; +export type BooleanFieldOutputInstance = z.infer; +export type BooleanFieldInputTemplate = z.infer; +export type BooleanFieldOutputTemplate = z.infer; +export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => + 
zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer; +export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer; +export type 
BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer; +export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer; +export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer; +export type MainModelFieldOutputInstance = z.infer; +export type MainModelFieldInputTemplate = z.infer; +export type MainModelFieldOutputTemplate = z.infer; +export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. 
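
The recipe listed in this file's header comment is easiest to see worked end to end. The following sketch (not part of the patch) shows how a hypothetical `SeedField` would be defined here, following exactly the pattern used by the fields in this region:

// Hypothetical SeedField - illustration of the recipe only, not part of this patch.
export const zSeedFieldType = zFieldTypeBase.extend({
  name: z.literal('SeedField'),
});
export const zSeedFieldValue = z.number().int().min(0);
export const zSeedFieldInputInstance = zFieldInputInstanceBase.extend({
  type: zSeedFieldType,
  value: zSeedFieldValue,
});
export const zSeedFieldOutputInstance = zFieldOutputInstanceBase.extend({
  type: zSeedFieldType,
});
export const zSeedFieldInputTemplate = zFieldInputTemplateBase.extend({
  type: zSeedFieldType,
  default: zSeedFieldValue,
});
export const zSeedFieldOutputTemplate = zFieldOutputTemplateBase.extend({
  type: zSeedFieldType,
});
export type SeedFieldInputInstance = z.infer<typeof zSeedFieldInputInstance>;
export type SeedFieldInputTemplate = z.infer<typeof zSeedFieldInputTemplate>;
export const isSeedFieldInputInstance = (val: unknown): val is SeedFieldInputInstance =>
  zSeedFieldInputInstance.safeParse(val).success;
export const isSeedFieldInputTemplate = (val: unknown): val is SeedFieldInputTemplate =>
  zSeedFieldInputTemplate.safeParse(val).success;
// Finally, each schema would be added to the corresponding union at the bottom of the file.
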
+export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, +}); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, +}); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer; +export type SDXLMainModelFieldOutputInstance = z.infer; +export type SDXLMainModelFieldInputTemplate = z.infer; +export type SDXLMainModelFieldOutputTemplate = z.infer; +export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export type SDXLRefinerModelFieldType = z.infer; +export type SDXLRefinerModelFieldValue = z.infer; +export type SDXLRefinerModelFieldInputInstance = z.infer; +export type SDXLRefinerModelFieldOutputInstance = z.infer; +export type SDXLRefinerModelFieldInputTemplate = z.infer; +export type SDXLRefinerModelFieldOutputTemplate = z.infer; +export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = 
z.infer; +export type VAEModelFieldInputInstance = z.infer; +export type VAEModelFieldOutputInstance = z.infer; +export type VAEModelFieldInputTemplate = z.infer; +export type VAEModelFieldOutputTemplate = z.infer; +export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer; +export type LoRAModelFieldOutputInstance = z.infer; +export type LoRAModelFieldInputTemplate = z.infer; +export type LoRAModelFieldOutputTemplate = z.infer; +export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, +}); +export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, +}); +export type ControlNetModelFieldType = z.infer; +export type ControlNetModelFieldValue = z.infer; +export type ControlNetModelFieldInputInstance = z.infer; +export type ControlNetModelFieldOutputInstance = z.infer; +export type ControlNetModelFieldInputTemplate = z.infer; +export type ControlNetModelFieldOutputTemplate = z.infer; +export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate => + zControlNetModelFieldInputTemplate.safeParse(val).success; +export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue => + zControlNetModelFieldValue.safeParse(v).success; +// #endregion + +// #region IPAdapterModelField +export const 
zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, +}); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + default: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, +}); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer; +export type IPAdapterModelFieldInputInstance = z.infer; +export type IPAdapterModelFieldOutputInstance = z.infer; +export type IPAdapterModelFieldInputTemplate = z.infer; +export type IPAdapterModelFieldOutputTemplate = z.infer; +export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue => + zIPAdapterModelFieldValue.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export type T2IAdapterModelFieldType = z.infer; +export type T2IAdapterModelFieldValue = z.infer; +export type T2IAdapterModelFieldInputInstance = z.infer; +export type T2IAdapterModelFieldOutputInstance = z.infer; +export type T2IAdapterModelFieldInputTemplate = z.infer; +export type T2IAdapterModelFieldOutputTemplate = z.infer; +export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSchedulerFieldType, + value: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSchedulerFieldType, +}); +export const zSchedulerFieldInputTemplate 
= zFieldInputTemplateBase.extend({
+  type: zSchedulerFieldType,
+  default: zSchedulerFieldValue,
+});
+export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+  type: zSchedulerFieldType,
+});
+export type SchedulerFieldType = z.infer<typeof zSchedulerFieldType>;
+export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
+export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
+export type SchedulerFieldOutputInstance = z.infer<typeof zSchedulerFieldOutputInstance>;
+export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
+export type SchedulerFieldOutputTemplate = z.infer<typeof zSchedulerFieldOutputTemplate>;
+export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
+  zSchedulerFieldInputInstance.safeParse(val).success;
+export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
+  zSchedulerFieldInputTemplate.safeParse(val).success;
+// #endregion
+
+// #region StatelessField
+/**
+ * StatelessField is a catchall for stateless fields with no UI input components. They do
+ * not support "direct" input, instead only accepting connections from other fields.
+ *
+ * This field type serves as a "generic" field type.
+ *
+ * Examples include:
+ * - Fields like UNetField or LatentsField where we do not allow direct UI input
+ * - Reserved fields like IsIntermediate
+ * - Any other field we don't have full-on schemas for
+ */
+export const zStatelessFieldType = zFieldTypeBase.extend({
+  name: z.string().min(1), // stateless --> we accept the field's name as the type
+});
+export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS finagling
+export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({
+  type: zStatelessFieldType,
+  value: zStatelessFieldValue,
+});
+export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({
+  type: zStatelessFieldType,
+});
+export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zStatelessFieldType,
+  default: zStatelessFieldValue,
+  input: z.literal('connection'), // stateless --> only accepts connection inputs
+});
+export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+  type: zStatelessFieldType,
+});
+
+export type StatelessFieldType = z.infer<typeof zStatelessFieldType>;
+export type StatelessFieldValue = z.infer<typeof zStatelessFieldValue>;
+export type StatelessFieldInputInstance = z.infer<typeof zStatelessFieldInputInstance>;
+export type StatelessFieldOutputInstance = z.infer<typeof zStatelessFieldOutputInstance>;
+export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTemplate>;
+export type StatelessFieldOutputTemplate = z.infer<typeof zStatelessFieldOutputTemplate>;
+// #endregion
+
+/**
+ * Here we define the main field unions:
+ * - FieldType
+ * - FieldValue
+ * - FieldInputInstance
+ * - FieldOutputInstance
+ * - FieldInputTemplate
+ * - FieldOutputTemplate
+ *
+ * All stateful fields are unioned together, and then that union is unioned with StatelessField.
+ *
+ * This allows us to interact with stateful fields without needing to worry about "generic" handling
+ * for all other StatelessFields.
+ */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer; +export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + 
zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer; +export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => + zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer; +export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate => + zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts new file mode 100644 index 00000000000..86ec70fd9bd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts @@ -0,0 +1,93 @@ +import type { Edge, Node } from 'reactflow'; +import { z } from 'zod'; + +import { zClassification, zProgressImage } from './common'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { 
+
+// #region InvocationTemplate
+export const zInvocationTemplate = z.object({
+  type: z.string(),
+  title: z.string(),
+  description: z.string(),
+  tags: z.array(z.string().min(1)),
+  inputs: z.record(zFieldInputTemplate),
+  outputs: z.record(zFieldOutputTemplate),
+  outputType: z.string().min(1),
+  version: zSemVer,
+  useCache: z.boolean(),
+  nodePack: z.string().min(1).nullish(),
+  classification: zClassification,
+});
+export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
+// #endregion
+
+// #region NodeData
+export const zInvocationNodeData = z.object({
+  id: z.string().trim().min(1),
+  type: z.string().trim().min(1),
+  label: z.string(),
+  isOpen: z.boolean(),
+  notes: z.string(),
+  isIntermediate: z.boolean(),
+  useCache: z.boolean(),
+  version: zSemVer,
+  nodePack: z.string().min(1).nullish(),
+  inputs: z.record(zFieldInputInstance),
+  outputs: z.record(zFieldOutputInstance),
+});
+
+export const zNotesNodeData = z.object({
+  id: z.string().trim().min(1),
+  type: z.literal('notes'),
+  label: z.string(),
+  isOpen: z.boolean(),
+  notes: z.string(),
+});
+export const zCurrentImageNodeData = z.object({
+  id: z.string().trim().min(1),
+  type: z.literal('current_image'),
+  label: z.string(),
+  isOpen: z.boolean(),
+});
+export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
+
+export type NotesNodeData = z.infer<typeof zNotesNodeData>;
+export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
+export type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
+export type AnyNodeData = z.infer<typeof zAnyNodeData>;
+
+export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
+export type NotesNode = Node<NotesNodeData, 'notes'>;
+export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
+export type AnyNode = Node<AnyNodeData>;
+
+export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation');
+export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes');
+export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode =>
+  Boolean(node && node.type === 'current_image');
+export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData =>
+  Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
+// #endregion
+
+// #region NodeExecutionState
+export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']);
+export const zNodeExecutionState = z.object({
+  nodeId: z.string().trim().min(1),
+  status: zNodeStatus,
+  progress: z.number().nullable(),
+  progressImage: zProgressImage.nullable(),
+  error: z.string().nullable(),
+  outputs: z.array(z.any()),
+});
+export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
+export type NodeStatus = z.infer<typeof zNodeStatus>;
+// #endregion
+
+// #region Edges
+export const zInvocationNodeEdgeExtra = z.object({
+  type: z.union([z.literal('default'), z.literal('collapsed')]),
+});
+export type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
+export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
+// #endregion
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts
new file mode 100644
index 00000000000..0cc30499e38
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts
@@ -0,0 +1,77 @@
+import { z } from 'zod';
+
+import {
+  zControlField,
+  zIPAdapterField,
+  zLoRAModelField,
+  zMainModelField,
+  zSDXLRefinerModelField,
+  zT2IAdapterField,
+  zVAEModelField,
+} from './common';
+
+// #region Metadata-optimized versions of schemas
+// TODO: It's possible that `deepPartial` will be deprecated:
+// - https://github.com/colinhacks/zod/issues/2106
+// - https://github.com/colinhacks/zod/issues/2854
+export const zLoRAMetadataItem = z.object({
+  lora: zLoRAModelField.deepPartial(),
+  weight: z.number(),
+});
+const zControlNetMetadataItem = zControlField.deepPartial();
+const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
+const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
+const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial();
+const zModelMetadataItem = zMainModelField.deepPartial();
+const zVAEModelMetadataItem = zVAEModelField.deepPartial();
+export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
+export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
+export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
+export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
+export type SDXLRefinerModelMetadataItem = z.infer<typeof zSDXLRefinerModelMetadataItem>;
+export type ModelMetadataItem = z.infer<typeof zModelMetadataItem>;
+export type VAEModelMetadataItem = z.infer<typeof zVAEModelMetadataItem>;
+// #endregion
+
+// #region CoreMetadata
+export const zCoreMetadata = z
+  .object({
+    app_version: z.string().nullish().catch(null),
+    generation_mode: z.string().nullish().catch(null),
+    created_by: z.string().nullish().catch(null),
+    positive_prompt: z.string().nullish().catch(null),
+    negative_prompt: z.string().nullish().catch(null),
+    width: z.number().int().nullish().catch(null),
+    height: z.number().int().nullish().catch(null),
+    seed: z.number().int().nullish().catch(null),
+    rand_device: z.string().nullish().catch(null),
+    cfg_scale: z.number().nullish().catch(null),
+    cfg_rescale_multiplier: z.number().nullish().catch(null),
+    steps: z.number().int().nullish().catch(null),
+    scheduler: z.string().nullish().catch(null),
+    clip_skip: z.number().int().nullish().catch(null),
+    model: zModelMetadataItem.nullish().catch(null),
+    controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
+    ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
+    t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
+    loras: z.array(zLoRAMetadataItem).nullish().catch(null),
+    vae: zVAEModelMetadataItem.nullish().catch(null),
+    strength: z.number().nullish().catch(null),
+    hrf_enabled: z.boolean().nullish().catch(null),
+    hrf_strength: z.number().nullish().catch(null),
+    hrf_method: z.string().nullish().catch(null),
+    init_image: z.string().nullish().catch(null),
+    positive_style_prompt: z.string().nullish().catch(null),
+    negative_style_prompt: z.string().nullish().catch(null),
+    refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null),
+    refiner_cfg_scale: z.number().nullish().catch(null),
+    refiner_steps: z.number().int().nullish().catch(null),
+    refiner_scheduler: z.string().nullish().catch(null),
+    refiner_positive_aesthetic_score: z.number().nullish().catch(null),
+    refiner_negative_aesthetic_score: z.number().nullish().catch(null),
+    refiner_start: z.number().nullish().catch(null),
+  })
+  .passthrough();
+export type CoreMetadata = z.infer<typeof zCoreMetadata>;
+
+// #endregion
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts
new file mode 100644
index 00000000000..83d774439a3
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts
@@ -0,0 +1,86 @@
+import type { OpenAPIV3_1 } from 'openapi-types';
+import type {
+  InputFieldJSONSchemaExtra,
+  InvocationJSONSchemaExtra,
+  OutputFieldJSONSchemaExtra,
+} from 'services/api/types';
+
+// Janky customization of OpenAPI Schema :/
+
+export type InvocationSchemaExtra = InvocationJSONSchemaExtra & {
+  output: OpenAPIV3_1.ReferenceObject; // the output of the invocation
+  title: string;
+  category?: string;
+  tags?: string[];
+  version: string;
+  properties: Omit<
+    NonNullable<OpenAPIV3_1.SchemaObject['properties']> & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra),
+    'type'
+  > & {
+    type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
+      default: string;
+    };
+    use_cache: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
+      default: boolean;
+    };
+  };
+};
+
+export type InvocationSchemaType = {
+  default: string; // the type of the invocation
+};
+
+export type InvocationBaseSchemaObject = Omit<OpenAPIV3_1.BaseSchemaObject, 'title'> &
+  InvocationSchemaExtra;
+
+export type InvocationOutputSchemaObject = Omit<OpenAPIV3_1.SchemaObject, 'properties'> & {
+  properties: OpenAPIV3_1.SchemaObject['properties'] & {
+    type: Omit<OpenAPIV3_1.SchemaObject, 'default'> & {
+      default: string;
+    };
+  } & {
+    class: 'output';
+  };
+};
+
+export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra;
+
+export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
+
+export interface ArraySchemaObject extends InvocationBaseSchemaObject {
+  type: OpenAPIV3_1.ArraySchemaObjectType;
+  items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject;
+}
+export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
+  type?: OpenAPIV3_1.NonArraySchemaObjectType;
+}
+
+export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' };
+
+export const isSchemaObject = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj));
+
+export const isArraySchemaObject = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array');
+
+export const isNonArraySchemaObject = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
+
+export const isRefObject = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined
+): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj);
+
+export const isInvocationSchemaObject = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject
+): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation';
+
+export const isInvocationOutputSchemaObject = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject
+): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output';
+
+export const isInvocationFieldSchema = (
+  obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject
+): obj is InvocationFieldSchema => !('$ref' in obj);
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts
new file mode 100644
index 00000000000..3ba330eac47
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts
@@ -0,0 +1,21 @@
+import { z } from 'zod';
+
+// Schemas and types for working with semver
+
+const zVersionInt = z.coerce.number().int().min(0);
+
+export const zSemVer = z.string().refine((val) => {
+  const [major, minor, patch] = val.split('.');
+  return (
+    zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success
+  );
+});
+
+export const zParsedSemver = zSemVer.transform((val) => {
+  const [major, minor, patch] = val.split('.');
+  return {
+    major: Number(major),
+    minor: Number(minor),
+    patch: Number(patch),
+  };
+});
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts
new file mode 100644
index 00000000000..723a354013b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts
@@ -0,0 +1,89 @@
+import { z } from 'zod';
+
+import { zFieldIdentifier } from './field';
+import { zInvocationNodeData, zNotesNodeData } from './invocation';
+
+// #region Workflow misc
+export const zXYPosition = z
+  .object({
+    x: z.number(),
+    y: z.number(),
+  })
+  .default({ x: 0, y: 0 });
+export type XYPosition = z.infer<typeof zXYPosition>;
+
+export const zDimension = z.number().gt(0).nullish();
+export type Dimension = z.infer<typeof zDimension>;
+
+export const zWorkflowCategory = z.enum(['user', 'default', 'project']);
+export type WorkflowCategory = z.infer<typeof zWorkflowCategory>;
+// #endregion
+
+// #region Workflow Nodes
+export const zWorkflowInvocationNode = z.object({
+  id: z.string().trim().min(1),
+  type: z.literal('invocation'),
+  data: zInvocationNodeData,
+  width: zDimension,
+  height: zDimension,
+  position: zXYPosition,
+});
+export const zWorkflowNotesNode = z.object({
+  id: z.string().trim().min(1),
+  type: z.literal('notes'),
+  data: zNotesNodeData,
+  width: zDimension,
+  height: zDimension,
+  position: zXYPosition,
+});
+export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
+
+export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
+export type WorkflowNotesNode = z.infer<typeof zWorkflowNotesNode>;
+export type WorkflowNode = z.infer<typeof zWorkflowNode>;
+
+export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode =>
+  zWorkflowInvocationNode.safeParse(val).success;
+// #endregion
+
+// #region Workflow Edges
+export const zWorkflowEdgeBase = z.object({
+  id: z.string().trim().min(1),
+  source: z.string().trim().min(1),
+  target: z.string().trim().min(1),
+});
+export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({
+  type: z.literal('default'),
+  sourceHandle: z.string().trim().min(1),
+  targetHandle: z.string().trim().min(1),
+});
+export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({
+  type: z.literal('collapsed'),
+});
+export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]);
+
+export type WorkflowEdgeDefault = z.infer<typeof zWorkflowEdgeDefault>;
+export type WorkflowEdgeCollapsed = z.infer<typeof zWorkflowEdgeCollapsed>;
+export type WorkflowEdge = z.infer<typeof zWorkflowEdge>;
+// #endregion
+
+// #region Workflow
+export const zWorkflowV2 = z.object({
+  id: z.string().min(1).optional(),
+  name: z.string(),
+  author: z.string(),
+  description: z.string(),
+  version: z.string(),
+  contact: z.string(),
+  tags: z.string(),
+  notes: z.string(),
+  nodes: z.array(zWorkflowNode),
+  edges: z.array(zWorkflowEdge),
+  exposedFields: z.array(zFieldIdentifier),
+  meta: z.object({
+    category: zWorkflowCategory.default('user'),
+    version: z.literal('2.0.0'),
+  }),
+});
+export type WorkflowV2 = z.infer<typeof zWorkflowV2>;
+// #endregion
diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts
index 723a354013b..adad7c0f219 100644
--- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts
@@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({
   id: z.string().trim().min(1),
   type: z.literal('invocation'),
   data: zInvocationNodeData,
-  width: zDimension,
-  height: zDimension,
   position: zXYPosition,
 });
 export const zWorkflowNotesNode = z.object({
   id: z.string().trim().min(1),
   type: z.literal('notes'),
   data: zNotesNodeData,
-  width: zDimension,
-  height: zDimension,
   position: zXYPosition,
 });
 export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
@@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer<typeof zWorkflowEdge>;
 // #endregion
 
 // #region Workflow
-export const zWorkflowV2 = z.object({
+export const zWorkflowV3 = z.object({
   id: z.string().min(1).optional(),
   name: z.string(),
   author: z.string(),
@@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({
   exposedFields: z.array(zFieldIdentifier),
   meta: z.object({
     category: zWorkflowCategory.default('user'),
-    version: z.literal('2.0.0'),
+    version: z.literal('3.0.0'),
   }),
 });
-export type WorkflowV2 = z.infer<typeof zWorkflowV2>;
+export type WorkflowV3 = z.infer<typeof zWorkflowV3>;
 // #endregion
diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
index ea40bd4660f..af19aa86eaf 100644
--- a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
@@ -1,5 +1,5 @@
 import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
-import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field';
+import type { FieldInputInstance } from 'features/nodes/types/field';
 import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
 import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
 import { reduce } from 'lodash-es';
@@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe
     {} as Record<string, FieldInputInstance>
   );
 
-  const outputs = reduce(
-    template.outputs,
-    (outputsAccumulator, outputTemplate, outputName) => {
-      const fieldId = uuidv4();
-
-      const outputFieldValue: FieldOutputInstance = {
-        id: fieldId,
-        name: outputName,
-        type: outputTemplate.type,
-        fieldKind: 'output',
-      };
-
-      outputsAccumulator[outputName] = outputFieldValue;
-
-      return outputsAccumulator;
-    },
-    {} as Record<string, FieldOutputInstance>
-  );
-
   const node: InvocationNode = {
     ...SHARED_NODE_PROPERTIES,
     id: nodeId,
@@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe
       isIntermediate: type === 'save_image' ? false : true,
       useCache: template.useCache,
       inputs,
-      outputs,
     },
   };
 
diff --git a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts
index f195c49d30a..5ece51d0f30 100644
--- a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts
@@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
 
   // Remove any fields that are not in the template
   clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs));
-  clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs));
 
   return clone;
 };
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
index dd3cf0ad7b4..f8097566c95 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
@@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
 
 export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
   const fieldInstance: FieldInputInstance = {
-    id,
     name: template.name,
-    type: template.type,
     label: '',
-    fieldKind: 'input' as const,
     value: template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name),
   };
 
diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
index 70775a98823..720da164648 100644
--- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
@@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger';
 import { parseify } from 'common/util/serialize';
 import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
 import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
-import type { WorkflowV2 } from 'features/nodes/types/workflow';
-import { zWorkflowV2 } from 'features/nodes/types/workflow';
+import type { WorkflowV3 } from 'features/nodes/types/workflow';
+import { zWorkflowV3 } from 'features/nodes/types/workflow';
 import i18n from 'i18n';
 import { cloneDeep, pick } from 'lodash-es';
 import { fromZodError } from 'zod-validation-error';
@@ -25,14 +25,14 @@ const workflowKeys = [
   'exposedFields',
   'meta',
   'id',
-] satisfies (keyof WorkflowV2)[];
+] satisfies (keyof WorkflowV3)[];
 
-export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2;
+export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;
 
-export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => {
+export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
   const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);
 
-  const newWorkflow: WorkflowV2 = {
+  const newWorkflow: WorkflowV3 = {
     ...clonedWorkflow,
     nodes: [],
     edges: [],
@@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
         type: node.type,
         data: cloneDeep(node.data),
         position: { ...node.position },
-        width: node.width,
-        height: node.height,
       });
     } else if (isNotesNode(node) && node.type) {
       newWorkflow.nodes.push({
@@ -54,8 +52,6 @@ export const buildWorkflowFast:
BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } }); @@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo return newWorkflow; }; -export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => { +export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => { // builds what really, really should be a valid workflow const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow }); // but bc we are storing this in the DB, let's be extra sure - const result = zWorkflowV2.safeParse(workflowToValidate); + const result = zWorkflowV3.safeParse(workflowToValidate); if (!result.success) { const { message } = fromZodError(result.error, { diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index a2677f3d174..a023c96ba92 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; +import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { z } from 'zod'; @@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = $store.get()?.getState().nodeTemplates.templates; + const invocationTemplates = $store.get()?.getState().nodes.templates; if (!invocationTemplates) { throw new Error(t('app.storeNotInitialized')); @@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { return zWorkflowV2.parse(workflowToMigrate); }; +const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { + // Bump version + (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; + // Parsing strips out any extra properties not in the latest version + return zWorkflowV3.parse(workflowToMigrate); +}; + /** * Parses a workflow and migrates it to the latest version if necessary. 
 */
-export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => {
+export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
   const workflowVersionResult = zWorkflowMetaVersion.safeParse(data);
 
   if (!workflowVersionResult.success) {
     throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
   }
 
-  const { version } = workflowVersionResult.data.meta;
+  let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3;
 
-  if (version === '1.0.0') {
-    const v1 = zWorkflowV1.parse(data);
-    return migrateV1toV2(v1);
+  if (workflow.meta.version === '1.0.0') {
+    const v1 = zWorkflowV1.parse(workflow);
+    workflow = migrateV1toV2(v1);
   }
 
-  if (version === '2.0.0') {
-    return zWorkflowV2.parse(data);
+  if (workflow.meta.version === '2.0.0') {
+    const v2 = zWorkflowV2.parse(workflow);
+    workflow = migrateV2toV3(v2);
   }
 
-  throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version }));
+  return workflow as WorkflowV3;
 };
diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
index 848d2aee77a..5096e588b06 100644
--- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
@@ -1,6 +1,6 @@
 import { parseify } from 'common/util/serialize';
 import type { InvocationTemplate } from 'features/nodes/types/invocation';
-import type { WorkflowV2 } from 'features/nodes/types/workflow';
+import type { WorkflowV3 } from 'features/nodes/types/workflow';
 import { isWorkflowInvocationNode } from 'features/nodes/types/workflow';
 import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
 import { t } from 'i18next';
@@ -16,7 +16,7 @@ type WorkflowWarning = {
 };
 
 type ValidateWorkflowResult = {
-  workflow: WorkflowV2;
+  workflow: WorkflowV3;
   warnings: WorkflowWarning[];
 };
 
diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts
index 5d484b68971..7b49d70213f 100644
--- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts
+++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts
@@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library';
 import { useAppDispatch } from 'app/store/storeHooks';
 import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher';
 import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice';
-import type { WorkflowV2 } from 'features/nodes/types/workflow';
+import type { WorkflowV3 } from 'features/nodes/types/workflow';
 import { workflowUpdated } from 'features/workflowLibrary/store/actions';
 import { useCallback, useRef } from 'react';
 import { useTranslation } from 'react-i18next';
@@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = {
 
 type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn;
 
-export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required<WorkflowV2, 'id'> =>
+export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required<WorkflowV3, 'id'> =>
   Boolean(workflow.id);
 
 export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => {

From 96ae22c7e0ac7be0ead1cee5e4537c5137158704 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 13 Feb 2024 22:51:44 +1100
Subject: [PATCH 069/100] feat(ui): add vitest

- Add vitest.
- Consolidate vite configs into single file (easier to config everything based on env for testing) --- invokeai/frontend/web/config/common.mts | 12 - .../frontend/web/config/vite.app.config.mts | 33 --- .../web/config/vite.package.config.mts | 46 ---- invokeai/frontend/web/package.json | 7 +- invokeai/frontend/web/pnpm-lock.yaml | 222 +++++++++++++++++- invokeai/frontend/web/vite.config.mts | 88 ++++++- 6 files changed, 306 insertions(+), 102 deletions(-) delete mode 100644 invokeai/frontend/web/config/common.mts delete mode 100644 invokeai/frontend/web/config/vite.app.config.mts delete mode 100644 invokeai/frontend/web/config/vite.package.config.mts diff --git a/invokeai/frontend/web/config/common.mts b/invokeai/frontend/web/config/common.mts deleted file mode 100644 index fd559cabd1e..00000000000 --- a/invokeai/frontend/web/config/common.mts +++ /dev/null @@ -1,12 +0,0 @@ -import react from '@vitejs/plugin-react-swc'; -import { visualizer } from 'rollup-plugin-visualizer'; -import type { PluginOption, UserConfig } from 'vite'; -import eslint from 'vite-plugin-eslint'; -import tsconfigPaths from 'vite-tsconfig-paths'; - -export const commonPlugins: UserConfig['plugins'] = [ - react(), - eslint(), - tsconfigPaths(), - visualizer() as unknown as PluginOption, -]; diff --git a/invokeai/frontend/web/config/vite.app.config.mts b/invokeai/frontend/web/config/vite.app.config.mts deleted file mode 100644 index 9683ed26a48..00000000000 --- a/invokeai/frontend/web/config/vite.app.config.mts +++ /dev/null @@ -1,33 +0,0 @@ -import type { UserConfig } from 'vite'; - -import { commonPlugins } from './common.mjs'; - -export const appConfig: UserConfig = { - base: './', - plugins: [...commonPlugins], - build: { - chunkSizeWarningLimit: 1500, - }, - server: { - // Proxy HTTP requests to the flask server - proxy: { - // Proxy socket.io to the nodes socketio server - '/ws/socket.io': { - target: 'ws://127.0.0.1:9090', - ws: true, - }, - // Proxy openapi schema definiton - '/openapi.json': { - target: 'http://127.0.0.1:9090/openapi.json', - rewrite: (path) => path.replace(/^\/openapi.json/, ''), - changeOrigin: true, - }, - // proxy nodes api - '/api/v1': { - target: 'http://127.0.0.1:9090/api/v1', - rewrite: (path) => path.replace(/^\/api\/v1/, ''), - changeOrigin: true, - }, - }, - }, -}; diff --git a/invokeai/frontend/web/config/vite.package.config.mts b/invokeai/frontend/web/config/vite.package.config.mts deleted file mode 100644 index 3c05d52e005..00000000000 --- a/invokeai/frontend/web/config/vite.package.config.mts +++ /dev/null @@ -1,46 +0,0 @@ -import path from 'path'; -import type { UserConfig } from 'vite'; -import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; -import dts from 'vite-plugin-dts'; - -import { commonPlugins } from './common.mjs'; - -export const packageConfig: UserConfig = { - base: './', - plugins: [ - ...commonPlugins, - dts({ - insertTypesEntry: true, - }), - cssInjectedByJsPlugin(), - ], - build: { - cssCodeSplit: true, - lib: { - entry: path.resolve(__dirname, '../src/index.ts'), - name: 'InvokeAIUI', - fileName: (format) => `invoke-ai-ui.${format}.js`, - }, - rollupOptions: { - external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'], - output: { - globals: { - react: 'React', - 'react-dom': 'ReactDOM', - '@emotion/react': 'EmotionReact', - '@invoke-ai/ui-library': 'UiLibrary', - }, - }, - }, - }, - resolve: { - alias: { - app: path.resolve(__dirname, '../src/app'), - assets: path.resolve(__dirname, '../src/assets'), - 
common: path.resolve(__dirname, '../src/common'), - features: path.resolve(__dirname, '../src/features'), - services: path.resolve(__dirname, '../src/services'), - theme: path.resolve(__dirname, '../src/theme'), - }, - }, -}; diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index cd95183c7a4..b2838e538ce 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -33,7 +33,9 @@ "preinstall": "npx only-allow pnpm", "storybook": "storybook dev -p 6006", "build-storybook": "storybook build", - "unimported": "npx unimported" + "unimported": "npx unimported", + "test": "vitest", + "test:no-watch": "vitest --no-watch" }, "madge": { "excludeRegExp": [ @@ -157,7 +159,8 @@ "vite-plugin-css-injected-by-js": "^3.3.1", "vite-plugin-dts": "^3.7.1", "vite-plugin-eslint": "^1.8.1", - "vite-tsconfig-paths": "^4.3.1" + "vite-tsconfig-paths": "^4.3.1", + "vitest": "^1.2.2" }, "pnpm": { "patchedDependencies": { diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 1d9083d1b44..f3bf68cf1da 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -215,7 +215,7 @@ devDependencies: version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12) '@storybook/test': specifier: ^7.6.10 - version: 7.6.10 + version: 7.6.10(vitest@1.2.2) '@storybook/theming': specifier: ^7.6.10 version: 7.6.10(react-dom@18.2.0)(react@18.2.0) @@ -318,6 +318,9 @@ devDependencies: vite-tsconfig-paths: specifier: ^4.3.1 version: 4.3.1(typescript@5.3.3)(vite@5.0.12) + vitest: + specifier: ^1.2.2 + version: 1.2.2(@types/node@20.11.5) packages: @@ -5464,7 +5467,7 @@ packages: - supports-color dev: true - /@storybook/test@7.6.10: + /@storybook/test@7.6.10(vitest@1.2.2): resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==} dependencies: '@storybook/client-logger': 7.6.10 @@ -5472,7 +5475,7 @@ packages: '@storybook/instrumenter': 7.6.10 '@storybook/preview-api': 7.6.10 '@testing-library/dom': 9.3.4 - '@testing-library/jest-dom': 6.2.0 + '@testing-library/jest-dom': 6.2.0(vitest@1.2.2) '@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4) '@types/chai': 4.3.11 '@vitest/expect': 0.34.7 @@ -5652,7 +5655,7 @@ packages: pretty-format: 27.5.1 dev: true - /@testing-library/jest-dom@6.2.0: + /@testing-library/jest-dom@6.2.0(vitest@1.2.2): resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==} engines: {node: '>=14', npm: '>=6', yarn: '>=1'} peerDependencies: @@ -5678,6 +5681,7 @@ packages: dom-accessibility-api: 0.6.3 lodash: 4.17.21 redent: 3.0.0 + vitest: 1.2.2(@types/node@20.11.5) dev: true /@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4): @@ -6490,12 +6494,42 @@ packages: chai: 4.4.1 dev: true + /@vitest/expect@1.2.2: + resolution: {integrity: sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==} + dependencies: + '@vitest/spy': 1.2.2 + '@vitest/utils': 1.2.2 + chai: 4.4.1 + dev: true + + /@vitest/runner@1.2.2: + resolution: {integrity: sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==} + dependencies: + '@vitest/utils': 1.2.2 + p-limit: 5.0.0 + pathe: 1.1.2 + dev: true + + /@vitest/snapshot@1.2.2: + resolution: {integrity: sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==} + dependencies: + 
magic-string: 0.30.5 + pathe: 1.1.2 + pretty-format: 29.7.0 + dev: true + /@vitest/spy@0.34.7: resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==} dependencies: tinyspy: 2.2.0 dev: true + /@vitest/spy@1.2.2: + resolution: {integrity: sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==} + dependencies: + tinyspy: 2.2.0 + dev: true + /@vitest/utils@0.34.7: resolution: {integrity: sha512-ziAavQLpCYS9sLOorGrFFKmy2gnfiNU0ZJ15TsMz/K92NAPS/rp9K4z6AJQQk5Y8adCy4Iwpxy7pQumQ/psnRg==} dependencies: @@ -6504,6 +6538,15 @@ packages: pretty-format: 29.7.0 dev: true + /@vitest/utils@1.2.2: + resolution: {integrity: sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==} + dependencies: + diff-sequences: 29.6.3 + estree-walker: 3.0.3 + loupe: 2.3.7 + pretty-format: 29.7.0 + dev: true + /@volar/language-core@1.11.1: resolution: {integrity: sha512-dOcNn3i9GgZAcJt43wuaEykSluAuOkQgzni1cuxLxTV0nJKanQztp7FxyswdRILaKH+P2XZMPRp2S4MV/pElCw==} dependencies: @@ -7184,6 +7227,11 @@ packages: engines: {node: '>=0.4.0'} dev: true + /acorn-walk@8.3.2: + resolution: {integrity: sha512-cjkyv4OtNCIeqhHrfS81QWXoCBPExR/J62oyEqepVw8WaQeSqpW2uhuLPh1m9eWhDuOo/jUXVTlifvesOWp/4A==} + engines: {node: '>=0.4.0'} + dev: true + /acorn@7.4.1: resolution: {integrity: sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==} engines: {node: '>=0.4.0'} @@ -7661,6 +7709,11 @@ packages: engines: {node: '>= 0.8'} dev: true + /cac@6.7.14: + resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==} + engines: {node: '>=8'} + dev: true + /call-bind@1.0.5: resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==} dependencies: @@ -9173,6 +9226,12 @@ packages: resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==} dev: true + /estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + dependencies: + '@types/estree': 1.0.5 + dev: true + /esutils@2.0.3: resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} engines: {node: '>=0.10.0'} @@ -10547,6 +10606,10 @@ packages: hasBin: true dev: true + /jsonc-parser@3.2.1: + resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==} + dev: true + /jsondiffpatch@0.6.0: resolution: {integrity: sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==} engines: {node: ^18.0.0 || >=20.0.0} @@ -10648,6 +10711,14 @@ packages: engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} dev: true + /local-pkg@0.5.0: + resolution: {integrity: sha512-ok6z3qlYyCDS4ZEU27HaU6x/xZa9Whf8jD4ptH5UZTQYZVYeb9bnZ3ojVhiJNLiXK1Hfc0GNbLXcmZ5plLDDBg==} + engines: {node: '>=14'} + dependencies: + mlly: 1.5.0 + pkg-types: 1.0.3 + dev: true + /locate-path@3.0.0: resolution: {integrity: sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==} engines: {node: '>=6'} @@ -10986,6 +11057,15 @@ packages: hasBin: true dev: true + /mlly@1.5.0: + resolution: {integrity: sha512-NPVQvAY1xr1QoVeG0cy8yUYC7FQcOx6evl/RjT1wL5FvzPnzOysoqB/jmx/DhssT2dYa8nxECLAaFI/+gVLhDQ==} + dependencies: + 
acorn: 8.11.3 + pathe: 1.1.2 + pkg-types: 1.0.3 + ufo: 1.3.2 + dev: true + /module-definition@3.4.0: resolution: {integrity: sha512-XxJ88R1v458pifaSkPNLUTdSPNVGMP2SXVncVmApGO+gAfrLANiYe6JofymCzVceGOMwQE2xogxBSc8uB7XegA==} engines: {node: '>=6.0'} @@ -11380,6 +11460,13 @@ packages: yocto-queue: 0.1.0 dev: true + /p-limit@5.0.0: + resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==} + engines: {node: '>=18'} + dependencies: + yocto-queue: 1.0.0 + dev: true + /p-locate@3.0.0: resolution: {integrity: sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==} engines: {node: '>=6'} @@ -11550,6 +11637,14 @@ packages: find-up: 5.0.0 dev: true + /pkg-types@1.0.3: + resolution: {integrity: sha512-nN7pYi0AQqJnoLPC9eHFQ8AcyaixBUOwvqc5TDnIKCMEE6I0y8P7OKA7fPexsXGCGxQDl/cmrLAp26LhcwxZ4A==} + dependencies: + jsonc-parser: 3.2.1 + mlly: 1.5.0 + pathe: 1.1.2 + dev: true + /pluralize@8.0.0: resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} engines: {node: '>=4'} @@ -12850,6 +12945,10 @@ packages: object-inspect: 1.13.1 dev: true + /siginfo@2.0.0: + resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==} + dev: true + /signal-exit@3.0.7: resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} dev: true @@ -12968,6 +13067,10 @@ packages: stackframe: 1.3.4 dev: false + /stackback@0.0.2: + resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==} + dev: true + /stackframe@1.3.4: resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==} dev: false @@ -12992,6 +13095,10 @@ packages: engines: {node: '>= 0.8'} dev: true + /std-env@3.7.0: + resolution: {integrity: sha512-JPbdCEQLj1w5GilpiHAx3qJvFndqybBysA3qUOnznweH4QbNYUsW/ea8QzSrnh0vNsezMMw5bcVool8lM0gwzg==} + dev: true + /stop-iteration-iterator@1.0.0: resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==} engines: {node: '>= 0.4'} @@ -13161,6 +13268,12 @@ packages: engines: {node: '>=8'} dev: true + /strip-literal@1.3.0: + resolution: {integrity: sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==} + dependencies: + acorn: 8.11.3 + dev: true + /stylis@4.2.0: resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==} dev: false @@ -13311,6 +13424,15 @@ packages: /tiny-invariant@1.3.1: resolution: {integrity: sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==} + /tinybench@2.6.0: + resolution: {integrity: sha512-N8hW3PG/3aOoZAN5V/NSAEDz0ZixDSSt5b/a05iqtpgfLWMSVuCo7w0k2vVvEjdrIoeGqZzweX2WlyioNIHchA==} + dev: true + + /tinypool@0.8.2: + resolution: {integrity: sha512-SUszKYe5wgsxnNOVlBYO6IC+8VGWdVGZWAqUxp3UErNBtptZvWbwyUOyzNL59zigz2rCA92QiL3wvG+JDSdJdQ==} + engines: {node: '>=14.0.0'} + dev: true + /tinyspy@2.2.0: resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==} engines: {node: '>=14.0.0'} @@ -13828,6 +13950,27 @@ packages: engines: {node: '>= 0.8'} dev: true + /vite-node@1.2.2(@types/node@20.11.5): + resolution: {integrity: 
sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==}
+    engines: {node: ^18.0.0 || >=20.0.0}
+    hasBin: true
+    dependencies:
+      cac: 6.7.14
+      debug: 4.3.4
+      pathe: 1.1.2
+      picocolors: 1.0.0
+      vite: 5.0.12(@types/node@20.11.5)
+    transitivePeerDependencies:
+      - '@types/node'
+      - less
+      - lightningcss
+      - sass
+      - stylus
+      - sugarss
+      - supports-color
+      - terser
+    dev: true
+
   /vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12):
     resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==}
     peerDependencies:
@@ -13926,6 +14069,63 @@
       fsevents: 2.3.3
     dev: true
 
+  /vitest@1.2.2(@types/node@20.11.5):
+    resolution: {integrity: sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==}
+    engines: {node: ^18.0.0 || >=20.0.0}
+    hasBin: true
+    peerDependencies:
+      '@edge-runtime/vm': '*'
+      '@types/node': ^18.0.0 || >=20.0.0
+      '@vitest/browser': ^1.0.0
+      '@vitest/ui': ^1.0.0
+      happy-dom: '*'
+      jsdom: '*'
+    peerDependenciesMeta:
+      '@edge-runtime/vm':
+        optional: true
+      '@types/node':
+        optional: true
+      '@vitest/browser':
+        optional: true
+      '@vitest/ui':
+        optional: true
+      happy-dom:
+        optional: true
+      jsdom:
+        optional: true
+    dependencies:
+      '@types/node': 20.11.5
+      '@vitest/expect': 1.2.2
+      '@vitest/runner': 1.2.2
+      '@vitest/snapshot': 1.2.2
+      '@vitest/spy': 1.2.2
+      '@vitest/utils': 1.2.2
+      acorn-walk: 8.3.2
+      cac: 6.7.14
+      chai: 4.4.1
+      debug: 4.3.4
+      execa: 8.0.1
+      local-pkg: 0.5.0
+      magic-string: 0.30.5
+      pathe: 1.1.2
+      picocolors: 1.0.0
+      std-env: 3.7.0
+      strip-literal: 1.3.0
+      tinybench: 2.6.0
+      tinypool: 0.8.2
+      vite: 5.0.12(@types/node@20.11.5)
+      vite-node: 1.2.2(@types/node@20.11.5)
+      why-is-node-running: 2.2.2
+    transitivePeerDependencies:
+      - less
+      - lightningcss
+      - sass
+      - stylus
+      - sugarss
+      - supports-color
+      - terser
+    dev: true
+
   /void-elements@3.1.0:
     resolution: {integrity: sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==}
     engines: {node: '>=0.10.0'}
@@ -14049,6 +14249,15 @@
       isexe: 2.0.0
     dev: true
 
+  /why-is-node-running@2.2.2:
+    resolution: {integrity: sha512-6tSwToZxTOcotxHeA+qGCq1mVzKR3CwcJGmVcY+QE8SHy6TnpFnh8PAvPNHYr7EcuVeG0QSMxtYCuO1ta/G/oA==}
+    engines: {node: '>=8'}
+    hasBin: true
+    dependencies:
+      siginfo: 2.0.0
+      stackback: 0.0.2
+    dev: true
+
   /wordwrap@1.0.0:
     resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==}
    dev: true
@@ -14189,6 +14398,11 @@
     engines: {node: '>=10'}
     dev: true
 
+  /yocto-queue@1.0.0:
+    resolution: {integrity: sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==}
+    engines: {node: '>=12.20'}
+    dev: true
+
   /z-schema@5.0.5:
     resolution: {integrity: sha512-D7eujBWkLa3p2sIpJA0d1pr7es+a7m0vFAnZLlCEKq/Ij2k0MLi9Br2UPxoxdYystm5K1yeBGzub0FlYUEWj2Q==}
     engines: {node: '>=8.0.0'}
diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts
index b76dd24b628..325c6467dee 100644
--- a/invokeai/frontend/web/vite.config.mts
+++ b/invokeai/frontend/web/vite.config.mts
@@ -1,12 +1,90 @@
+/// <reference types="vitest" />
+import react from '@vitejs/plugin-react-swc';
+import path from 'path';
+import { visualizer } from 'rollup-plugin-visualizer';
+import type { PluginOption } from 'vite';
 import { defineConfig } from 'vite';
-
-import { appConfig } from './config/vite.app.config.mjs';
-import { packageConfig } from './config/vite.package.config.mjs';
+import
cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; +import dts from 'vite-plugin-dts'; +import eslint from 'vite-plugin-eslint'; +import tsconfigPaths from 'vite-tsconfig-paths'; export default defineConfig(({ mode }) => { if (mode === 'package') { - return packageConfig; + return { + base: './', + plugins: [ + react(), + eslint(), + tsconfigPaths(), + visualizer() as unknown as PluginOption, + dts({ + insertTypesEntry: true, + }), + cssInjectedByJsPlugin(), + ], + build: { + cssCodeSplit: true, + lib: { + entry: path.resolve(__dirname, '../src/index.ts'), + name: 'InvokeAIUI', + fileName: (format) => `invoke-ai-ui.${format}.js`, + }, + rollupOptions: { + external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'], + output: { + globals: { + react: 'React', + 'react-dom': 'ReactDOM', + '@emotion/react': 'EmotionReact', + '@invoke-ai/ui-library': 'UiLibrary', + }, + }, + }, + }, + resolve: { + alias: { + app: path.resolve(__dirname, '../src/app'), + assets: path.resolve(__dirname, '../src/assets'), + common: path.resolve(__dirname, '../src/common'), + features: path.resolve(__dirname, '../src/features'), + services: path.resolve(__dirname, '../src/services'), + theme: path.resolve(__dirname, '../src/theme'), + }, + }, + }; } - return appConfig; + return { + base: './', + plugins: [react(), mode !== 'test' && eslint(), tsconfigPaths(), visualizer() as unknown as PluginOption], + build: { + chunkSizeWarningLimit: 1500, + }, + server: { + // Proxy HTTP requests to the flask server + proxy: { + // Proxy socket.io to the nodes socketio server + '/ws/socket.io': { + target: 'ws://127.0.0.1:9090', + ws: true, + }, + // Proxy openapi schema definiton + '/openapi.json': { + target: 'http://127.0.0.1:9090/openapi.json', + rewrite: (path) => path.replace(/^\/openapi.json/, ''), + changeOrigin: true, + }, + // proxy nodes api + '/api/v1': { + target: 'http://127.0.0.1:9090/api/v1', + rewrite: (path) => path.replace(/^\/api\/v1/, ''), + changeOrigin: true, + }, + }, + }, + test: { + // + }, + }; }); From f22eb368a34071ce85487e1cd2505d9a88e3aa94 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 22:53:30 +1100 Subject: [PATCH 070/100] feat(ui): add more types of FieldParseError Unfortunately you cannot test for both a specific type of error and match its message. Splitting the error classes makes it easier to test expected error conditions. 
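For illustration, a minimal sketch of the testing pattern this enables (illustrative only, not part of this patch; the `parse` helper and its error message below are hypothetical stand-ins for parseFieldType). With distinct subclasses, vitest's `toThrowError` can match on the error class itself rather than string-matching the message:

  import { expect, it } from 'vitest';

  class FieldParseError extends Error {}
  class UnsupportedUnionError extends FieldParseError {}

  // Hypothetical parser that signals a specific failure mode via the subclass.
  const parse = (schema: { anyOf?: unknown[] }): string => {
    if (schema.anyOf && schema.anyOf.length > 2) {
      throw new UnsupportedUnionError('unions of more than 2 types are not supported');
    }
    return 'ok';
  };

  it('asserts the exact failure mode without depending on message text', () => {
    // Matching on the class keeps the test stable if the message is reworded.
    expect(() => parse({ anyOf: [1, 2, 3] })).toThrowError(UnsupportedUnionError);
  });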
--- .../web/src/features/nodes/types/error.ts | 5 ++++ .../nodes/util/schema/parseFieldType.ts | 30 +++++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts index 905b487fb04..c3da136c7a8 100644 --- a/invokeai/frontend/web/src/features/nodes/types/error.ts +++ b/invokeai/frontend/web/src/features/nodes/types/error.ts @@ -56,3 +56,8 @@ export class FieldParseError extends Error { this.name = this.constructor.name; } } + +export class UnableToExtractSchemaNameFromRefError extends FieldParseError {} +export class UnsupportedArrayItemType extends FieldParseError {} +export class UnsupportedUnionError extends FieldParseError {} +export class UnsupportedPrimitiveTypeError extends FieldParseError {} \ No newline at end of file diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 14b1aefd6d3..13da6b38312 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -1,6 +1,12 @@ -import { FieldParseError } from 'features/nodes/types/error'; +import { + FieldParseError, + UnableToExtractSchemaNameFromRefError, + UnsupportedArrayItemType, + UnsupportedPrimitiveTypeError, + UnsupportedUnionError, +} from 'features/nodes/types/error'; import type { FieldType } from 'features/nodes/types/field'; -import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; import { isArraySchemaObject, isInvocationFieldSchema, @@ -42,7 +48,7 @@ const isCollectionFieldType = (fieldType: string) => { return false; }; -export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => { +export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => { if (isInvocationFieldSchema(schemaObject)) { // Check if this field has an explicit type provided by the node schema const { ui_type } = schemaObject; @@ -72,7 +78,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // This is a single ref type const name = refObjectToSchemaName(allOf[0]); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -95,7 +101,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (isRefObject(filteredAnyOf[0])) { const name = refObjectToSchemaName(filteredAnyOf[0]); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { @@ -118,7 +124,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (filteredAnyOf.length !== 2) { // This is a union of more than 2 types, which we don't support - throw new FieldParseError( + throw new UnsupportedUnionError( t('nodes.unsupportedAnyOfLength', { count: filteredAnyOf.length, }) @@ -159,7 +165,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType }; } - throw new FieldParseError( + throw new UnsupportedUnionError( t('nodes.unsupportedMismatchedUnion', { firstType, secondType, @@ -178,7 
+184,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (isSchemaObject(schemaObject.items)) { const itemType = schemaObject.items.type; if (!itemType || isArray(itemType)) { - throw new FieldParseError( + throw new UnsupportedArrayItemType( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -188,7 +194,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new FieldParseError( + throw new UnsupportedArrayItemType( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -204,7 +210,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // This is a ref object, extract the type name const name = refObjectToSchemaName(schemaObject.items); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -216,7 +222,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new FieldParseError( + throw new UnsupportedPrimitiveTypeError( t('nodes.unsupportedArrayItemType', { type: schemaObject.type, }) @@ -232,7 +238,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } else if (isRefObject(schemaObject)) { const name = refObjectToSchemaName(schemaObject); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, From f505ec64bafa990e05896bcc09054cfd9d9414d6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 22:53:52 +1100 Subject: [PATCH 071/100] tests(ui): add `parseFieldType.test.ts` --- .../nodes/util/schema/parseFieldType.test.ts | 379 ++++++++++++++++++ 1 file changed, 379 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts new file mode 100644 index 00000000000..2f4ce48a326 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -0,0 +1,379 @@ +import { + UnableToExtractSchemaNameFromRefError, + UnsupportedArrayItemType, + UnsupportedPrimitiveTypeError, + UnsupportedUnionError, +} from 'features/nodes/types/error'; +import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType'; +import { describe, expect, it } from 'vitest'; + +type ParseFieldTypeTestCase = { + name: string; + schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema; + expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean }; +}; + +const primitiveTypes: ParseFieldTypeTestCase[] = [ + { + name: 'Scalar IntegerField', + schema: { type: 'integer' }, + expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar FloatField', + schema: { type: 'number' }, + expected: { name: 'FloatField', isCollection: false, 
isCollectionOrScalar: false }, + }, + { + name: 'Scalar StringField', + schema: { type: 'string' }, + expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar BooleanField', + schema: { type: 'boolean' }, + expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Collection IntegerField', + schema: { items: { type: 'integer' }, type: 'array' }, + expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection FloatField', + schema: { items: { type: 'number' }, type: 'array' }, + expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection StringField', + schema: { items: { type: 'string' }, type: 'array' }, + expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection BooleanField', + schema: { items: { type: 'boolean' }, type: 'array' }, + expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'CollectionOrScalar IntegerField', + schema: { + anyOf: [ + { + type: 'integer', + }, + { + items: { + type: 'integer', + }, + type: 'array', + }, + ], + }, + expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar FloatField', + schema: { + anyOf: [ + { + type: 'number', + }, + { + items: { + type: 'number', + }, + type: 'array', + }, + ], + }, + expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar StringField', + schema: { + anyOf: [ + { + type: 'string', + }, + { + items: { + type: 'string', + }, + type: 'array', + }, + ], + }, + expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar BooleanField', + schema: { + anyOf: [ + { + type: 'boolean', + }, + { + items: { + type: 'boolean', + }, + type: 'array', + }, + ], + }, + expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true }, + }, +]; + +const complexTypes: ParseFieldTypeTestCase[] = [ + { + name: 'Scalar ConditioningField', + schema: { + allOf: [ + { + $ref: '#/components/schemas/ConditioningField', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Nullable Scalar ConditioningField', + schema: { + anyOf: [ + { + $ref: '#/components/schemas/ConditioningField', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Collection ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Nullable Collection ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'CollectionOrScalar ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + $ref: '#/components/schemas/ConditioningField', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, 
isCollectionOrScalar: true }, + }, + { + name: 'Nullable CollectionOrScalar ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + $ref: '#/components/schemas/ConditioningField', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + }, +]; + +const specialCases: ParseFieldTypeTestCase[] = [ + { + name: 'String EnumField', + schema: { + type: 'string', + enum: ['large', 'base', 'small'], + }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'String EnumField with one value', + schema: { + const: 'Some Value', + }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (SchedulerField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'SchedulerField', + }, + expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (AnyField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'AnyField', + }, + expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (CollectionField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'CollectionField', + }, + expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + }, +]; + +describe('refObjectToSchemaName', async () => { + it('parses ref object 1', () => { + expect( + refObjectToSchemaName({ + $ref: '#/components/schemas/ImageField', + }) + ).toEqual('ImageField'); + }); + it('parses ref object 2', () => { + expect( + refObjectToSchemaName({ + $ref: '#/components/schemas/T2IAdapterModelField', + }) + ).toEqual('T2IAdapterModelField'); + }); +}); + +describe.concurrent('parseFieldType', async () => { + it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + it.each(specialCases)('parses special case types ($name)', ({ schema, expected }) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + + it('raises if it cannot extract a schema name from a ref', () => { + expect(() => + parseFieldType({ + allOf: [ + { + $ref: '#/components/schemas/', + }, + ], + }) + ).toThrowError(UnableToExtractSchemaNameFromRefError); + }); + + it('raises if it receives a union of mismatched types', () => { + expect(() => + parseFieldType({ + anyOf: [ + { + type: 'string', + }, + { + type: 'integer', + }, + ], + }) + ).toThrowError(UnsupportedUnionError); + }); + + it('raises if it receives a union of mismatched types (excluding null)', () => { + expect(() => + parseFieldType({ + anyOf: [ + { + type: 'string', + }, + { + type: 'integer', + }, + { + type: 'null', + }, + ], + }) + ).toThrowError(UnsupportedUnionError); + }); + + it('raises if it received an unsupported primitive type (object)', () => { + expect(() => + parseFieldType({ + type: 'object', + }) + ).toThrowError(UnsupportedPrimitiveTypeError); + }); + + it('raises if it received an unsupported primitive type (null)', () => { + expect(() => + parseFieldType({ + type: 'null', + }) + ).toThrowError(UnsupportedPrimitiveTypeError); + }); + + it('raises if it received an 
unsupported array item type (object)', () => { + expect(() => + parseFieldType({ + items: { + type: 'object', + }, + type: 'array', + }) + ).toThrowError(UnsupportedArrayItemType); + }); + + it('raises if it received an unsupported array item type (null)', () => { + expect(() => + parseFieldType({ + items: { + type: 'null', + }, + type: 'array', + }) + ).toThrowError(UnsupportedArrayItemType); + }); +}); From a1307b9f2e591dd01b04f9dbf248d6b94b1cb04b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 22 Jan 2024 14:37:23 -0500 Subject: [PATCH 072/100] add concept of repo variant --- invokeai/backend/model_manager/config.py | 4 +- invokeai/backend/model_manager/probe.py | 19 ++++++++++ tests/test_model_probe.py | 9 ++++- .../vae/taesdxl-fp16/config.json | 37 +++++++++++++++++++ .../diffusion_pytorch_model.fp16.safetensors | 0 5 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 tests/test_model_probe/vae/taesdxl-fp16/config.json create mode 100644 tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 964cc19f196..b4685caf10d 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - + repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False - class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index cd048d2fe78..ba3ac3dd0cc 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -20,6 +20,7 @@ ModelFormat, ModelType, ModelVariantType, + ModelRepoVariant, SchedulerPredictionType, ) from .hash import FastModelHash @@ -155,6 +156,9 @@ def probe( fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash + if format_type == ModelFormat.Diffusers: + fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() + # additional fields needed for main and controlnet models if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: fields["config"] = cls._get_checkpoint_config_path( @@ -477,6 +481,20 @@ def get_variant_type(self) -> ModelVariantType: def get_format(self) -> ModelFormat: return ModelFormat("diffusers") + def get_repo_variant(self) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(self.model_path.glob('**/*.safetensors')) + weight_files.extend(list(self.model_path.glob('**/*.bin'))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return 
ModelRepoVariant.OPENVINO + if "flax_model" in x.name: + return ModelRepoVariant.FLAX + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.DEFAULT class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: @@ -522,6 +540,7 @@ def get_variant_type(self) -> ModelVariantType: except Exception: pass return ModelVariantType.Normal + class VaeFolderProbe(FolderProbeBase): diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 248b7d602fd..415559a64cd 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -3,7 +3,7 @@ import pytest from invokeai.backend import BaseModelType -from invokeai.backend.model_management.model_probe import VaeFolderProbe +from invokeai.backend.model_manager.probe import VaeFolderProbe @pytest.mark.parametrize( @@ -20,3 +20,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat probe = VaeFolderProbe(sd1_vae_path) base_type = probe.get_base_type() assert base_type == expected_type + repo_variant = probe.get_repo_variant() + assert repo_variant == 'default' + +def test_repo_variant(datadir: Path): + probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") + repo_variant = probe.get_repo_variant() + assert repo_variant == 'fp16' diff --git a/tests/test_model_probe/vae/taesdxl-fp16/config.json b/tests/test_model_probe/vae/taesdxl-fp16/config.json new file mode 100644 index 00000000000..62f01c3eb44 --- /dev/null +++ b/tests/test_model_probe/vae/taesdxl-fp16/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +} diff --git a/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors b/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors new file mode 100644 index 00000000000..e69de29bb2d From 5c2884569ebc979738e1802bb9d605f1db635370 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 31 Jan 2024 23:37:59 -0500 Subject: [PATCH 073/100] add ram cache module and support files --- invokeai/backend/model_manager/config.py | 3 + .../backend/model_manager/load/__init__.py | 0 .../backend/model_manager/load/load_base.py | 193 ++++++++++ .../model_manager/load/load_default.py | 168 +++++++++ .../model_manager/load/memory_snapshot.py | 100 ++++++ .../backend/model_manager/load/model_util.py | 109 ++++++ .../model_manager/load/optimizations.py | 30 ++ .../model_manager/load/ram_cache/__init__.py | 0 .../load/ram_cache/ram_cache_base.py | 145 ++++++++ .../load/ram_cache/ram_cache_default.py | 332 ++++++++++++++++++ invokeai/backend/model_manager/load/vae.py | 31 ++ .../backend/model_manager/onnx_runtime.py | 216 ++++++++++++ invokeai/backend/model_manager/probe.py | 8 +- tests/test_model_probe.py | 5 +- 14 files changed, 1334 insertions(+), 6 deletions(-) create mode 100644 invokeai/backend/model_manager/load/__init__.py create mode 100644 invokeai/backend/model_manager/load/load_base.py create mode 100644 invokeai/backend/model_manager/load/load_default.py create mode 100644 
invokeai/backend/model_manager/load/memory_snapshot.py create mode 100644 invokeai/backend/model_manager/load/model_util.py create mode 100644 invokeai/backend/model_manager/load/optimizations.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/__init__.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py create mode 100644 invokeai/backend/model_manager/load/vae.py create mode 100644 invokeai/backend/model_manager/onnx_runtime.py diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b4685caf10d..338669c873a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -152,6 +152,7 @@ class _DiffusersConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT + class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -179,6 +180,7 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -214,6 +216,7 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False + class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py new file mode 100644 index 00000000000..7cb7222b717 --- /dev/null +++ b/invokeai/backend/model_manager/load/load_base.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +""" +Base class for model loading in InvokeAI. + +Use like this: + + loader = AnyModelLoader(...) 
+    loaded_model = loader.get_model('019ab39adfa1840455')
+    with loaded_model as model:  # context manager moves model into VRAM
+        # do something with loaded_model
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from logging import Logger
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Type, Union
+
+import torch
+from diffusers import DiffusionPipeline
+from injector import inject
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.model_records import ModelRecordServiceBase
+from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
+from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.model_manager.ram_cache import ModelCacheBase
+
+AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]
+
+
+class ModelLockerBase(ABC):
+    """Base class for the model locker used by the loader."""
+
+    @abstractmethod
+    def lock(self) -> None:
+        """Lock the contained model and move it into VRAM."""
+        pass
+
+    @abstractmethod
+    def unlock(self) -> None:
+        """Unlock the contained model, and remove it from VRAM."""
+        pass
+
+    @property
+    @abstractmethod
+    def model(self) -> AnyModel:
+        """Return the model."""
+        pass
+
+
+@dataclass
+class LoadedModel:
+    """Context manager object that mediates transfer from RAM<->VRAM."""
+
+    config: AnyModelConfig
+    locker: ModelLockerBase
+
+    def __enter__(self) -> AnyModel:
+        """Context entry."""
+        self.locker.lock()
+        return self.model
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        """Context exit."""
+        self.locker.unlock()
+
+    @property
+    def model(self) -> AnyModel:
+        """Return the model without locking it."""
+        return self.locker.model
+
+
+class ModelLoaderBase(ABC):
+    """Abstract base class for loading models into RAM/VRAM."""
+
+    @abstractmethod
+    def __init__(
+        self,
+        app_config: InvokeAIAppConfig,
+        logger: Logger,
+        ram_cache: ModelCacheBase,
+        convert_cache: ModelConvertCacheBase,
+    ):
+        """Initialize the loader."""
+        pass
+
+    @abstractmethod
+    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+        """
+        Return a model given its configuration.
+
+        Given a model configuration, as returned by the model record backend,
+        return a LoadedModel object that can be used to retrieve the model.
+
+        :param model_config: Model configuration, as returned by ModelConfigRecordStore
+        :param submodel_type: a SubModelType enum indicating the portion of
+           the model to retrieve (e.g. SubModelType.Vae)
+        """
+        pass
+
+    @abstractmethod
+    def get_size_fs(
+        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
+    ) -> int:
+        """Return size in bytes of the model, calculated before loading."""
+        pass
+
+
+# TO DO: Better name? 
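+# Usage sketch (illustrative only; the decorator matches the register()
+# classmethod defined below, and the concrete VAE loader registered this way
+# appears later in this series in load/vae.py -- the model key and the
+# run_inference() call are hypothetical):
+#
+#   @AnyModelLoader.register(type=ModelType.Vae, format=ModelFormat.Diffusers)
+#   class VaeDiffusersModel(ModelLoader):
+#       ...
+#
+#   loaded = loader.get_model("some-model-key")
+#   with loaded as model:
+#       run_inference(model)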
+class AnyModelLoader:
+    """This class manages the model loaders and invokes the correct one to load a model of given base and type."""
+
+    # this tracks the loader subclasses
+    _registry: Dict[str, Type[ModelLoaderBase]] = {}
+
+    @inject
+    def __init__(
+        self,
+        store: ModelRecordServiceBase,
+        app_config: InvokeAIAppConfig,
+        logger: Logger,
+        ram_cache: ModelCacheBase,
+        convert_cache: ModelConvertCacheBase,
+    ):
+        """Store the provided ModelRecordServiceBase and empty the registry."""
+        self._store = store
+        self._app_config = app_config
+        self._logger = logger
+        self._ram_cache = ram_cache
+        self._convert_cache = convert_cache
+
+    def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+        """
+        Return a model given its key.
+
+        Given a model key identified in the model configuration backend,
+        return a LoadedModel object that can be used to retrieve the model.
+
+        :param key: model key, as known to the config backend
+        :param submodel_type: a SubModelType enum indicating the portion of
+           the model to retrieve (e.g. SubModelType.Vae)
+        """
+        model_config = self._store.get_model(key)
+        implementation = self.__class__.get_implementation(
+            base=model_config.base, type=model_config.type, format=model_config.format
+        )
+        return implementation(
+            app_config=self._app_config,
+            logger=self._logger,
+            ram_cache=self._ram_cache,
+            convert_cache=self._convert_cache,
+        ).load_model(model_config, submodel_type)
+
+    @staticmethod
+    def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
+        return "-".join([base.value, type.value, format.value])
+
+    @classmethod
+    def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]:
+        """Get subclass of ModelLoaderBase registered to handle base and type."""
+        key1 = cls._to_registry_key(base, type, format)  # for a specific base type
+        key2 = cls._to_registry_key(BaseModelType.Any, type, format)  # with wildcard Any
+        implementation = cls._registry.get(key1) or cls._registry.get(key2)
+        if not implementation:
+            raise NotImplementedError(
+                f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
+            )
+        return implementation
+
+    @classmethod
+    def register(
+        cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
+    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
+        """Define a decorator which registers the subclass of loader."""
+
+        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
+            print("Registering class", subclass.__name__)
+            key = cls._to_registry_key(base, type, format)
+            cls._registry[key] = subclass
+            return subclass
+
+        return decorator
+
+
+# in __init__.py will call something like
+# def configure_loader_dependencies(binder):
+#     binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton)
+#     binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton)
+#  etc
+#  injector = Injector(configure_loader_dependencies)
+#  loader = injector.get(ModelFactory)
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
new file mode 100644
index 00000000000..eb2d432aaae
--- /dev/null
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team
+"""Default implementation of model loading in InvokeAI."""
+
+import sys
+from logging import Logger
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+from diffusers import ModelMixin
+from diffusers.configuration_utils import ConfigMixin
+from injector import inject
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
+from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
+from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
+from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
+from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
+from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+
+
+class ConfigLoader(ConfigMixin):
+    """Subclass of ConfigMixin for loading diffusers configuration files."""
+
+    @classmethod
+    def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Load a diffusers ConfigMixin configuration."""
+        cls.config_name = kwargs.pop("config_name")
+        # Diffusers doesn't provide typing info
+        return super().load_config(*args, **kwargs)  # type: ignore
+
+
+# TO DO: The loader is not thread safe!
+class ModelLoader(ModelLoaderBase):
+    """Default implementation of ModelLoaderBase."""
+
+    @inject  # can inject instances of each of the classes in the call signature
+    def __init__(
+        self,
+        app_config: InvokeAIAppConfig,
+        logger: Logger,
+        ram_cache: ModelCacheBase,
+        convert_cache: ModelConvertCacheBase,
+    ):
+        """Initialize the loader."""
+        self._app_config = app_config
+        self._logger = logger
+        self._ram_cache = ram_cache
+        self._convert_cache = convert_cache
+        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._size: Optional[int] = None  # model size
+
+    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+        """
+        Return a model given its configuration.
+
+        Given a model's configuration as returned by the ModelRecordConfigStore service,
+        return a LoadedModel object that can be used for inference.
+
+        :param model_config: Configuration record for this model
+        :param submodel_type: a SubModelType enum indicating the portion of
+           the model to retrieve (e.g. SubModelType.Vae)
+        """
+        if model_config.type == "main" and not submodel_type:
+            raise InvalidModelConfigException("submodel_type is required when loading a main model")
+
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+        if is_submodel_override:
+            submodel_type = None
+
+        if not model_path.exists():
+            raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
+
+        model_path = self._convert_if_needed(model_config, model_path, submodel_type)
+        locker = self._load_if_needed(model_config, model_path, submodel_type)
+        return LoadedModel(config=model_config, locker=locker)
+
+    # IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
+    # and submodels!!!!
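+    # A sketch of what such an override might look like (illustrative only; the
+    # StableDiffusion subclass is not part of this commit, and `config.vae` is a
+    # hypothetical override field):
+    #
+    #   def _get_model_path(self, config, submodel_type=None):
+    #       model_path, _ = super()._get_model_path(config, submodel_type)
+    #       if submodel_type is SubModelType.Vae and getattr(config, "vae", None):
+    #           return Path(config.vae), True  # substitute the override, drop the submodel
+    #       return model_path, False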
+ def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, bool]: + model_base = self._app_config.models_path + return ((model_base / config.path).resolve(), False) + + def _convert_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> Path: + if not self._needs_conversion(config): + return model_path + + self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) + cache_path: Path = self._convert_cache.cache_path(config.key) + if cache_path.exists(): + return cache_path + + self._convert_model(model_path, cache_path) + return cache_path + + def _needs_conversion(self, config: AnyModelConfig) -> bool: + return False + + def _load_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> ModelLockerBase: + # TO DO: This is not thread safe! + if self._ram_cache.exists(config.key, submodel_type): + return self._ram_cache.get(config.key, submodel_type) + + model_variant = getattr(config, "repo_variant", None) + self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + + # This is where the model is actually loaded! + with skip_torch_weight_init(): + loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type) + + self._ram_cache.put( + config.key, + submodel_type=submodel_type, + model=loaded_model, + ) + + return self._ram_cache.get(config.key, submodel_type) + + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Get the size of the model on disk.""" + return calc_model_size_by_fs( + model_path=model_path, + subfolder=submodel_type.value if submodel_type else None, + variant=config.repo_variant if hasattr(config, "repo_variant") else None, + ) + + def _convert_model(self, model_path: Path, cache_path: Path) -> None: + raise NotImplementedError + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + raise NotImplementedError + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + # TO DO: Add exception handling + def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + if submodel_type: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + else: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config["_class_name"] + return self._hf_definition_to_type(module="diffusers", class_name=class_name) diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py new file mode 100644 index 00000000000..504829a4271 --- /dev/null +++ 
b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -0,0 +1,100 @@ +import gc +from typing import Optional + +import psutil +import torch +from typing_extensions import Self + +from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 + +GB = 2**30 # 1 GB + + +class MemorySnapshot: + """A snapshot of RAM and VRAM usage. All values are in bytes.""" + + def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]): + """Initialize a MemorySnapshot. + + Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. + + Args: + process_ram (int): CPU RAM used by the current process. + vram (Optional[int]): VRAM used by torch. + malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil. + """ + self.process_ram = process_ram + self.vram = vram + self.malloc_info = malloc_info + + @classmethod + def capture(cls, run_garbage_collector: bool = True) -> Self: + """Capture and return a MemorySnapshot. + + Note: This function has significant overhead, particularly if `run_garbage_collector == True`. + + Args: + run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM + usage. Defaults to True. + + Returns: + MemorySnapshot + """ + if run_garbage_collector: + gc.collect() + + # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is + # supported on all platforms. + process_ram = psutil.Process().memory_info().rss + + if torch.cuda.is_available(): + vram = torch.cuda.memory_allocated() + else: + # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have + # time to test it properly. + vram = None + + try: + malloc_info = LibcUtil().mallinfo2() # type: ignore + except (OSError, AttributeError): + # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. + # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) + # TODO: Does `mallinfo` work? 
+ malloc_info = None + + return cls(process_ram, vram, malloc_info) + + +def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: + """Get a pretty string describing the difference between two `MemorySnapshot`s.""" + + def get_msg_line(prefix: str, val1: int, val2: int) -> str: + diff = val2 - val1 + return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" + + msg = "" + + if snapshot_1 is None or snapshot_2 is None: + return msg + + msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) + + if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: + msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd) + + msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks) + + msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks) + + libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd + libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2) + + libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd + libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) + + if snapshot_1.vram is not None and snapshot_2.vram is not None: + msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) + + return msg diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py new file mode 100644 index 00000000000..18407cbca2e --- /dev/null +++ b/invokeai/backend/model_manager/load/model_util.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024 The InvokeAI Development Team +"""Various utility functions needed by the loader and caching system.""" + +import json +from pathlib import Path +from typing import Optional, Union + +import torch +from diffusers import DiffusionPipeline + +from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel + + +def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int: + """Get size of a model in memory in bytes.""" + if isinstance(model, DiffusionPipeline): + return _calc_pipeline_by_data(model) + elif isinstance(model, torch.nn.Module): + return _calc_model_by_data(model) + elif isinstance(model, IAIOnnxRuntimeModel): + return _calc_onnx_model_by_data(model) + else: + return 0 + + +def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int: + res = 0 + assert hasattr(pipeline, "components") + for submodel_key in pipeline.components.keys(): + submodel = getattr(pipeline, submodel_key) + if submodel is not None and isinstance(submodel, torch.nn.Module): + res += _calc_model_by_data(submodel) + return res + + +def _calc_model_by_data(model: torch.nn.Module) -> int: + mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) + mem: int = mem_params + mem_bufs # in bytes + return mem + + +def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int: + tensor_size = model.tensors.size() * 2 # The session doubles this + mem = tensor_size # in bytes + return mem + + +def 
calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int: + """Estimate the size of a model on disk in bytes.""" + if subfolder is not None: + model_path = model_path / subfolder + + # this can happen when, for example, the safety checker is not downloaded. + if not model_path.exists(): + return 0 + + all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()] + + fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name} + bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name} + other_files = set(all_files) - fp16_files - bit8_files + + if variant is None: + files = other_files + elif variant == "fp16": + files = fp16_files + elif variant == "8bit": + files = bit8_files + else: + raise NotImplementedError(f"Unknown variant: {variant}") + + # try read from index if exists + index_postfix = ".index.json" + if variant is not None: + index_postfix = f".index.{variant}.json" + + for file in files: + if not file.name.endswith(index_postfix): + continue + try: + with open(model_path / file, "r") as f: + index_data = json.loads(f.read()) + return int(index_data["metadata"]["total_size"]) + except Exception: + pass + + # calculate files size if there is no index file + formats = [ + (".safetensors",), # safetensors + (".bin",), # torch + (".onnx", ".pb"), # onnx + (".msgpack",), # flax + (".ckpt",), # tf + (".h5",), # tf2 + ] + + for file_format in formats: + model_files = [f for f in files if f.suffix in file_format] + if len(model_files) == 0: + continue + + model_size = 0 + for model_file in model_files: + file_stats = (model_path / model_file).stat() + model_size += file_stats.st_size + return model_size + + return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu diff --git a/invokeai/backend/model_manager/load/optimizations.py b/invokeai/backend/model_manager/load/optimizations.py new file mode 100644 index 00000000000..a46d262175f --- /dev/null +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -0,0 +1,30 @@ +from contextlib import contextmanager + +import torch + + +def _no_op(*args, **kwargs): + pass + + +@contextmanager +def skip_torch_weight_init(): + """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) + to skip weight initialization. + + By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular + distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is + completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager + monkey-patches common torch layers to skip the weight initialization step. 
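+
+    Example (a minimal usage sketch; `saved_weights` is a hypothetical state dict):
+
+        with skip_torch_weight_init():
+            layer = torch.nn.Linear(4096, 4096)  # init skipped; weights are left unset
+        layer.load_state_dict(saved_weights)  # real values come from the checkpoint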
+    """
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
+    saved_functions = [m.reset_parameters for m in torch_modules]
+
+    try:
+        for torch_module in torch_modules:
+            torch_module.reset_parameters = _no_op
+
+        yield None
+    finally:
+        for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
+            torch_module.reset_parameters = saved_function
diff --git a/invokeai/backend/model_manager/load/ram_cache/__init__.py b/invokeai/backend/model_manager/load/ram_cache/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py b/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py
new file mode 100644
index 00000000000..cd80d1e78b2
--- /dev/null
+++ b/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
+# TODO: Add Stalker's proper name to copyright
+"""
+Manage a RAM cache of diffusion/transformer models for fast switching.
+They are moved between GPU VRAM and CPU RAM as necessary. If the cache
+grows larger than a preset maximum, then the least recently used
+model will be cleared and (re)loaded from disk when next needed.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from logging import Logger
+from typing import Dict, Optional
+
+import torch
+
+from invokeai.backend.model_manager import SubModelType
+from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
+
+
+@dataclass
+class CacheStats(object):
+    """Data object to record statistics on cache hits/misses."""
+
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
+
+
+@dataclass
+class CacheRecord:
+    """Elements of the cache."""
+
+    key: str
+    model: AnyModel
+    size: int
+    _locks: int = 0
+
+    def lock(self) -> None:
+        """Lock this record."""
+        self._locks += 1
+
+    def unlock(self) -> None:
+        """Unlock this record."""
+        self._locks -= 1
+        assert self._locks >= 0
+
+    @property
+    def locked(self) -> bool:
+        """Return true if record is locked."""
+        return self._locks > 0
+
+
+class ModelCacheBase(ABC):
+    """Virtual base class for RAM model cache."""
+
+    @property
+    @abstractmethod
+    def storage_device(self) -> torch.device:
+        """Return the storage device (e.g. "CPU" for RAM)."""
+        pass
+
+    @property
+    @abstractmethod
+    def execution_device(self) -> torch.device:
+        """Return the execution device (e.g. 
"cuda" for VRAM).""" + pass + + @property + @abstractmethod + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self) -> None: + """Offload from VRAM any models not actively in use.""" + pass + + @abstractmethod + def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None: + """Move model into the indicated device.""" + pass + + @property + @abstractmethod + def logger(self) -> Logger: + """Return the logger used by the cache.""" + pass + + @abstractmethod + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + pass + + @abstractmethod + def put( + self, + key: str, + model: AnyModel, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + pass + + @abstractmethod + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> ModelLockerBase: + """ + Retrieve model locker object using key and optional submodel_type. + + This may return an UnknownModelException if the model is not in the cache. + """ + pass + + @abstractmethod + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + pass + + @abstractmethod + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + pass + + @abstractmethod + def get_stats(self) -> CacheStats: + """Return cache hit/miss/size statistics.""" + pass + + @abstractmethod + def print_cuda_stats(self) -> None: + """Log debugging information on CUDA usage.""" + pass diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py b/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py new file mode 100644 index 00000000000..bd43e978c83 --- /dev/null +++ b/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py @@ -0,0 +1,332 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. + +The cache returns context manager generators designed to load the +model into the GPU within the context, and unload outside the +context. 
Use like this: + + cache = ModelCache(max_cache_size=7.5) + with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, + cache.get_model('stabilityai/stable-diffusion-2') as SD2: + do_something_in_GPU(SD1,SD2) + + +""" + +import math +import time +from contextlib import suppress +from logging import Logger +from typing import Any, Dict, List, Optional + +import torch + +from invokeai.app.services.model_records import UnknownModelException +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data +from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase +from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.logging import InvokeAILogger + +if choose_torch_device() == torch.device("mps"): + from torch import mps + +# Maximum size of the cache, in gigs +# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously +DEFAULT_MAX_CACHE_SIZE = 6.0 + +# amount of GPU memory to hold in reserve for use by generations (GB) +DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 + +# actual size of a gig +GIG = 1073741824 + +# Size of a MB in bytes. +MB = 2**20 + + +class ModelCache(ModelCacheBase): + """Implementation of ModelCacheBase.""" + + def __init__( + self, + max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, + max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, + execution_device: torch.device = torch.device("cuda"), + storage_device: torch.device = torch.device("cpu"), + precision: torch.dtype = torch.float16, + sequential_offload: bool = False, + lazy_offloading: bool = True, + sha_chunksize: int = 16777216, + log_memory_usage: bool = False, + logger: Optional[Logger] = None, + ): + """ + Initialize the model RAM cache. + + :param max_cache_size: Maximum size of the RAM cache [6.0 GB] + :param execution_device: Torch device to load active model into [torch.device('cuda')] + :param storage_device: Torch device to save inactive model in [torch.device('cpu')] + :param precision: Precision for loaded models [torch.float16] + :param lazy_offloading: Keep model in VRAM until another model needs to be loaded + :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially + :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache + operation, and the result will be logged (at debug level). There is a time cost to capturing the memory + snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's + behaviour. 
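+
+        Example (a minimal construction sketch; the values shown are illustrative):
+
+            cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=2.0)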
+        """
+        # allow lazy offloading only when vram cache enabled
+        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
+        self._precision: torch.dtype = precision
+        self._max_cache_size: float = max_cache_size
+        self._max_vram_cache_size: float = max_vram_cache_size
+        self._execution_device: torch.device = execution_device
+        self._storage_device: torch.device = storage_device
+        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
+        self._log_memory_usage = log_memory_usage
+
+        # used for stats collection
+        self.stats = None
+
+        self._cached_models: Dict[str, CacheRecord] = {}
+        self._cache_stack: List[str] = []
+
+    class ModelLocker(ModelLockerBase):
+        """Internal class that mediates movement in and out of GPU."""
+
+        def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord):
+            """
+            Initialize the model locker.
+
+            :param cache: The ModelCache object
+            :param cache_entry: The entry in the model cache
+            """
+            self._cache = cache
+            self._cache_entry = cache_entry
+
+        @property
+        def model(self) -> AnyModel:
+            """Return the model without moving it around."""
+            return self._cache_entry.model
+
+        def lock(self) -> Any:
+            """Move the model into the execution device (GPU) and lock it."""
+            if not hasattr(self.model, "to"):
+                return self.model
+
+            # NOTE that the model has to have the to() method in order for this code to move it into GPU!
+            self._cache_entry.lock()
+
+            try:
+                if self._cache.lazy_offloading:
+                    self._cache.offload_unlocked_models()
+
+                self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
+
+                self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+                self._cache.print_cuda_stats()
+
+            except Exception:
+                self._cache_entry.unlock()
+                raise
+            return self.model
+
+        def unlock(self) -> None:
+            """Call upon exit from context."""
+            if not hasattr(self.model, "to"):
+                return
+
+            self._cache_entry.unlock()
+            if not self._cache.lazy_offloading:
+                self._cache.offload_unlocked_models()
+                self._cache.print_cuda_stats()
+
+    @property
+    def logger(self) -> Logger:
+        """Return the logger used by the cache."""
+        return self._logger
+
+    @property
+    def lazy_offloading(self) -> bool:
+        """Return true if the cache is configured to lazily offload models in VRAM."""
+        return self._lazy_offloading
+
+    @property
+    def storage_device(self) -> torch.device:
+        """Return the storage device (e.g. "CPU" for RAM)."""
+        return self._storage_device
+
+    @property
+    def execution_device(self) -> torch.device:
+        """Return the execution device (e.g. 
"cuda" for VRAM).""" + return self._execution_device + + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + total = 0 + for cache_record in self._cached_models.values(): + total += cache_record.size + return total + + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + key = self._make_cache_key(key, submodel_type) + return key in self._cached_models + + def put( + self, + key: str, + model: AnyModel, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + key = self._make_cache_key(key, submodel_type) + assert key not in self._cached_models + + loaded_model_size = calc_model_size_by_data(model) + cache_record = CacheRecord(key, model, loaded_model_size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) + + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + This may return an UnknownModelException if the model is not in the cache. + """ + key = self._make_cache_key(key, submodel_type) + if key not in self._cached_models: + raise UnknownModelException + + # this moves the entry to the top (right end) of the stack + with suppress(Exception): + self._cache_stack.remove(key) + self._cache_stack.append(key) + cache_entry = self._cached_models[key] + return self.ModelLocker( + cache=self, + cache_entry=cache_entry, + ) + + def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: + if self._log_memory_usage: + return MemorySnapshot.capture() + return None + + def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: + if submodel_type: + return f"{model_key}:{submodel_type.value}" + else: + return model_key + + def offload_unlocked_models(self) -> None: + """Move any unused models from VRAM.""" + reserved = self._max_vram_cache_size * GIG + vram_in_use = torch.cuda.memory_allocated() + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): + if vram_in_use <= reserved: + break + if not cache_entry.locked: + self.move_model_to_device(cache_entry, self.storage_device) + + vram_in_use = torch.cuda.memory_allocated() + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + # TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size + # for printing debugging messages. Revisit whether this is necessary + def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None: + """Move model into the indicated device.""" + # These attributes are not in the base class but in derived classes + assert hasattr(cache_entry.model, "device") + assert hasattr(cache_entry.model, "to") + + source_device = cache_entry.model.device + + # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support + # multi-GPU. 
+ if torch.device(source_device).type == torch.device(target_device).type: + return + + start_model_to_time = time.time() + snapshot_before = self._capture_memory_snapshot() + cache_entry.model.to(target_device) + snapshot_after = self._capture_memory_snapshot() + end_model_to_time = time.time() + self.logger.debug( + f"Moved model '{cache_entry.key}' from {source_device} to" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + if ( + snapshot_before is not None + and snapshot_after is not None + and snapshot_before.vram is not None + and snapshot_after.vram is not None + ): + vram_change = abs(snapshot_before.vram - snapshot_after.vram) + + # If the estimated model size does not match the change in VRAM, log a warning. + if not math.isclose( + vram_change, + cache_entry.size, + rel_tol=0.1, + abs_tol=10 * MB, + ): + self.logger.debug( + f"Moving model '{cache_entry.key}' from {source_device} to" + f" {target_device} caused an unexpected change in VRAM usage. The model's" + " estimated size may be incorrect. Estimated model size:" + f" {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + def print_cuda_stats(self) -> None: + """Log CUDA diagnostics.""" + vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) + ram = "%4.2fG" % self.cache_size() + + cached_models = 0 + loaded_models = 0 + locked_models = 0 + for cache_record in self._cached_models.values(): + cached_models += 1 + assert hasattr(cache_record.model, "device") + if cache_record.model.device is self.storage_device: + loaded_models += 1 + if cache_record.locked: + locked_models += 1 + + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" + f" {cached_models}/{loaded_models}/{locked_models}" + ) + + def get_stats(self) -> CacheStats: + """Return cache hit/miss/size statistics.""" + raise NotImplementedError + + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + raise NotImplementedError diff --git a/invokeai/backend/model_manager/load/vae.py b/invokeai/backend/model_manager/load/vae.py new file mode 100644 index 00000000000..a6cbe241e1e --- /dev/null +++ b/invokeai/backend/model_manager/load/vae.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path +from typing import Dict, Optional + +import torch + +from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +class VaeDiffusersModel(ModelLoader): + """Class to load VAE models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> Dict[str, torch.Tensor]: + if submodel_type is not None: + raise Exception("There are no submodels in VAEs") + vae_class = self._get_hf_load_class(model_path) + variant = model_variant.value if model_variant else "" + result: Dict[str, torch.Tensor] = vae_class.from_pretrained( + model_path, torch_dtype=self._torch_dtype, variant=variant + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/onnx_runtime.py b/invokeai/backend/model_manager/onnx_runtime.py new file mode 100644 index 00000000000..f79fa015692 --- /dev/null +++ b/invokeai/backend/model_manager/onnx_runtime.py @@ -0,0 +1,216 @@ +# Copyright (c) 2024 The InvokeAI Development Team +import os +import sys +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import onnx +from onnx import numpy_helper +from onnxruntime import InferenceSession, SessionOptions, get_available_providers + +ONNX_WEIGHTS_NAME = "model.onnx" + + +# NOTE FROM LS: This was copied from Stalker's original implementation. 
+# I have not yet gone through and fixed all the type hints
+class IAIOnnxRuntimeModel:
+    class _tensor_access:
+        def __init__(self, model):  # type: ignore
+            self.model = model
+            self.indexes = {}
+            for idx, obj in enumerate(self.model.proto.graph.initializer):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):  # type: ignore
+            value = self.model.proto.graph.initializer[self.indexes[key]]
+            return numpy_helper.to_array(value)
+
+        def __setitem__(self, key: str, value: np.ndarray):  # type: ignore
+            new_node = numpy_helper.from_array(value)
+            # set_external_data(new_node, location="in-memory-location")
+            new_node.name = key
+            # new_node.ClearField("raw_data")
+            del self.model.proto.graph.initializer[self.indexes[key]]
+            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
+            # self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
+
+        # __delitem__
+
+        def __contains__(self, key: str) -> bool:
+            return key in self.indexes
+
+        def items(self) -> List[Tuple[str, Any]]:  # fixme
+            raise NotImplementedError("tensor.items")
+            # return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self) -> List[str]:
+            return list(self.indexes.keys())
+
+        def values(self) -> List[Any]:  # fixme
+            raise NotImplementedError("tensor.values")
+            # return [obj for obj in self.raw_proto]
+
+        def size(self) -> int:
+            bytesSum = 0
+            for node in self.model.proto.graph.initializer:
+                bytesSum += sys.getsizeof(node.raw_data)
+            return bytesSum
+
+    class _access_helper:
+        def __init__(self, raw_proto):  # type: ignore
+            self.indexes = {}
+            self.raw_proto = raw_proto
+            for idx, obj in enumerate(raw_proto):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):  # type: ignore
+            return self.raw_proto[self.indexes[key]]
+
+        def __setitem__(self, key: str, value):  # type: ignore
+            index = self.indexes[key]
+            del self.raw_proto[index]
+            self.raw_proto.insert(index, value)
+
+        # __delitem__
+
+        def __contains__(self, key: str) -> bool:
+            return key in self.indexes
+
+        def items(self) -> List[Tuple[str, Any]]:
+            return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self) -> List[str]:
+            return list(self.indexes.keys())
+
+        def values(self) -> List[Any]:  # fixme
+            return list(self.raw_proto)
+
+    def __init__(self, model_path: str, provider: Optional[str]):
+        self.path = model_path
+        self.session = None
+        self.provider = provider
+        """
+        self.data_path = self.path + "_data"
+        if not os.path.exists(self.data_path):
+            print(f"Moving model tensors to separate file: {self.data_path}")
+            tmp_proto = onnx.load(model_path, load_external_data=True)
+            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
+            del tmp_proto
+            gc.collect()

+        self.proto = onnx.load(model_path, load_external_data=False)
+        """
+
+        self.proto = onnx.load(model_path, load_external_data=True)
+        # self.data = dict()
+        # for tensor in self.proto.graph.initializer:
+        #     name = tensor.name

+        #     if tensor.HasField("raw_data"):
+        #         npt = numpy_helper.to_array(tensor)
+        #         orv = OrtValue.ortvalue_from_numpy(npt)
+        #         # self.data[name] = orv
+        #         # set_external_data(tensor, location="in-memory-location")
+        #         tensor.name = name
+        #         # tensor.ClearField("raw_data")
+
+        self.nodes = self._access_helper(self.proto.graph.node)  # type: ignore
+        # self.initializers = self._access_helper(self.proto.graph.initializer)
+        # print(self.proto.graph.input)
+        # 
print(self.proto.graph.initializer)
+
+        self.tensors = self._tensor_access(self)  # type: ignore
+
+    # TODO: integrate with model manager/cache
+    def create_session(self, height=None, width=None):
+        if self.session is None or self.session_width != width or self.session_height != height:
+            # onnx.save(self.proto, "tmp.onnx")
+            # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
+            # TODO: something to be able to get weight when they already moved outside of model proto
+            # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
+            sess = SessionOptions()
+            # self._external_data.update(**external_data)
+            # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+            # sess.enable_profiling = True
+
+            # sess.intra_op_num_threads = 1
+            # sess.inter_op_num_threads = 1
+            # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
+            # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+            # sess.enable_cpu_mem_arena = True
+            # sess.enable_mem_pattern = True
+            # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
+            self.session_height = height
+            self.session_width = width
+            if height and width:
+                sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
+                sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
+                sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
+                sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
+                sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
+                sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
+                sess.add_free_dimension_override_by_name("unet_time_batch", 1)
+            providers = []
+            if self.provider:
+                providers.append(self.provider)
+            else:
+                providers = get_available_providers()
+            if "TensorrtExecutionProvider" in providers:
+                providers.remove("TensorrtExecutionProvider")
+            try:
+                self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
+            except Exception as e:
+                raise e
+            # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
+            # self.io_binding = self.session.io_binding()
+
+    def release_session(self):
+        self.session = None
+        import gc
+
+        gc.collect()
+        return
+
+    def __call__(self, **kwargs):
+        if self.session is None:
+            raise Exception("You should call create_session before running model")
+
+        inputs = {k: np.array(v) for k, v in kwargs.items()}
+        # output_names = self.session.get_outputs()
+        # for k in inputs:
+        #     self.io_binding.bind_cpu_input(k, inputs[k])
+        # for name in output_names:
+        #     self.io_binding.bind_output(name.name)
+        # self.session.run_with_iobinding(self.io_binding, None)
+        # return self.io_binding.copy_outputs_to_cpu()
+        return self.session.run(None, inputs)
+
+    # compatibility with diffusers load code
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        subfolder: Optional[Union[str, Path]] = None,
+        file_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        sess_options: Optional["SessionOptions"] = None,
+        **kwargs: Any,
+    ) -> Any:  # fixme
+        file_name = file_name or ONNX_WEIGHTS_NAME
+
+        if os.path.isdir(model_id):
+            model_path = model_id
+            if subfolder is not None:
+                model_path = os.path.join(model_path, subfolder)
+            model_path = os.path.join(model_path, file_name)
+
+        else:
+            model_path = 
model_id + + # load model from local directory + if not os.path.isfile(model_path): + raise Exception(f"Model not found: {model_path}") + + # TODO: session options + return cls(str(model_path), provider=provider) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index ba3ac3dd0cc..9fd118b7822 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -18,9 +18,9 @@ InvalidModelConfigException, ModelConfigFactory, ModelFormat, + ModelRepoVariant, ModelType, ModelVariantType, - ModelRepoVariant, SchedulerPredictionType, ) from .hash import FastModelHash @@ -483,8 +483,8 @@ def get_format(self) -> ModelFormat: def get_repo_variant(self) -> ModelRepoVariant: # get all files ending in .bin or .safetensors - weight_files = list(self.model_path.glob('**/*.safetensors')) - weight_files.extend(list(self.model_path.glob('**/*.bin'))) + weight_files = list(self.model_path.glob("**/*.safetensors")) + weight_files.extend(list(self.model_path.glob("**/*.bin"))) for x in weight_files: if ".fp16" in x.suffixes: return ModelRepoVariant.FP16 @@ -496,6 +496,7 @@ def get_repo_variant(self) -> ModelRepoVariant: return ModelRepoVariant.ONNX return ModelRepoVariant.DEFAULT + class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: with open(self.model_path / "unet" / "config.json", "r") as file: @@ -540,7 +541,6 @@ def get_variant_type(self) -> ModelVariantType: except Exception: pass return ModelVariantType.Normal - class VaeFolderProbe(FolderProbeBase): diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 415559a64cd..aacae06a8bb 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -21,9 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == 'default' + assert repo_variant == "default" + def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() - assert repo_variant == 'fp16' + assert repo_variant == "fp16" From 60aa3d4893d42bdd80c33aea40f21eb94a3caef9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 3 Feb 2024 22:55:09 -0500 Subject: [PATCH 074/100] model loading and conversion implemented for vaes --- invokeai/app/api/dependencies.py | 17 +- .../app/services/config/config_default.py | 21 +- .../model_install/model_install_default.py | 5 +- .../model_records/model_records_base.py | 15 +- .../model_records/model_records_sql.py | 43 +- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../sqlite_migrator/migrations/migration_6.py | 44 + invokeai/backend/install/install_helper.py | 11 +- invokeai/backend/model_manager/__init__.py | 4 + invokeai/backend/model_manager/config.py | 10 +- .../convert_ckpt_to_diffusers.py | 1744 +++++++++++++++++ .../backend/model_manager/load/__init__.py | 35 + .../load/convert_cache/__init__.py | 4 + .../load/convert_cache/convert_cache_base.py | 28 + .../convert_cache/convert_cache_default.py | 64 + .../backend/model_manager/load/load_base.py | 72 +- .../model_manager/load/load_default.py | 23 +- .../load/model_cache/__init__.py | 5 + .../model_cache_base.py} | 54 +- .../model_cache_default.py} | 202 +- .../load/model_cache/model_locker.py | 59 + .../load/model_loaders/__init__.py | 3 + .../model_manager/load/model_loaders/vae.py | 83 + 
.../backend/model_manager/load/model_util.py | 3 + .../model_manager/load/ram_cache/__init__.py | 0 invokeai/backend/model_manager/load/vae.py | 31 - invokeai/backend/util/__init__.py | 12 +- invokeai/backend/util/devices.py | 5 +- invokeai/backend/util/util.py | 14 + 29 files changed, 2379 insertions(+), 234 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py create mode 100644 invokeai/backend/model_manager/convert_ckpt_to_diffusers.py create mode 100644 invokeai/backend/model_manager/load/convert_cache/__init__.py create mode 100644 invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py create mode 100644 invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py create mode 100644 invokeai/backend/model_manager/load/model_cache/__init__.py rename invokeai/backend/model_manager/load/{ram_cache/ram_cache_base.py => model_cache/model_cache_base.py} (77%) rename invokeai/backend/model_manager/load/{ram_cache/ram_cache_default.py => model_cache/model_cache_default.py} (63%) create mode 100644 invokeai/backend/model_manager/load/model_cache/model_locker.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/__init__.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/vae.py delete mode 100644 invokeai/backend/model_manager/load/ram_cache/__init__.py delete mode 100644 invokeai/backend/model_manager/load/vae.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0f2a92b5c8e..dcb8d219971 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -8,6 +8,8 @@ from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache +from invokeai.backend.model_manager.load.model_cache import ModelCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger @@ -98,15 +100,26 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) + model_loader = AnyModelLoader( + app_config=config, + logger=logger, + ram_cache=ModelCache( + max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger + ), + convert_cache=ModelConvertCache( + cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size + ), + ) + model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader) download_queue_service = DownloadQueueService(event_bus=events) - metadata_store = ModelMetadataStore(db=db) model_install_service = ModelInstallService( app_config=config, record_store=model_record_service, download_queue=download_queue_service, - metadata_store=metadata_store, + metadata_store=ModelMetadataStore(db=db), event_bus=events, ) + model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. 
Remove names = SimpleNameService() performance_statistics = InvocationStatsService() processor = DefaultInvocationProcessor() diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 132afc22722..b161ea18d61 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -237,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings): autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) + convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths) db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths) outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths) @@ -262,6 +263,8 @@ class InvokeAIAppConfig(InvokeAISettings): # CACHE ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) + lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). 
There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
@@ -404,6 +407,11 @@ def models_path(self) -> Path:
         """Path to the models directory."""
         return self._resolve(self.models_dir)

+    @property
+    def models_convert_cache_path(self) -> Path:
+        """Path to the converted models cache directory."""
+        return self._resolve(self.convert_cache_dir)
+
     @property
     def custom_nodes_path(self) -> Path:
         """Path to the custom nodes directory."""
@@ -433,15 +441,20 @@ def invisible_watermark(self) -> bool:
         return True

     @property
-    def ram_cache_size(self) -> Union[Literal["auto"], float]:
-        """Return the ram cache size using the legacy or modern setting."""
+    def ram_cache_size(self) -> float:
+        """Return the ram cache size using the legacy or modern setting (GB)."""
         return self.max_cache_size or self.ram

     @property
-    def vram_cache_size(self) -> Union[Literal["auto"], float]:
-        """Return the vram cache size using the legacy or modern setting."""
+    def vram_cache_size(self) -> float:
+        """Return the vram cache size using the legacy or modern setting (GB)."""
         return self.max_vram_cache_size or self.vram

+    @property
+    def convert_cache_size(self) -> float:
+        """Return the convert cache size on disk (GB)."""
+        return self.convert_cache
+
     @property
     def use_cpu(self) -> bool:
         """Return true if the device is set to CPU or the always_use_cpu flag is set."""
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 82c667f584f..2b2294bfce4 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -145,7 +145,7 @@ def register_path(
     ) -> str:  # noqa D102
         model_path = Path(model_path)
         config = config or {}
-        if config.get("source") is None:
+        if not config.get("source"):
             config["source"] = model_path.resolve().as_posix()
         return self._register(model_path, config)

@@ -156,7 +156,7 @@ def install_path(
     ) -> str:  # noqa D102
         model_path = Path(model_path)
         config = config or {}
-        if config.get("source") is None:
+        if not config.get("source"):
             config["source"] = model_path.resolve().as_posix()

         info: AnyModelConfig = self._probe_model(Path(model_path), config)
@@ -300,6 +300,7 @@ def _install_next_item(self) -> None:
             job.total_bytes = self._stat_size(job.local_path)
             job.bytes = job.total_bytes
             self._signal_job_running(job)
+            job.config_in["source"] = str(job.source)
             if job.inplace:
                 key = self.register_path(job.local_path, job.config_in)
             else:
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 57597570cde..31cfecb4ec8 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -11,7 +11,7 @@
 from pydantic import BaseModel, Field

 from invokeai.app.services.shared.pagination import PaginatedResults
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore

@@ -102,6 +102,19 @@ def get_model(self, key: str) -> AnyModelConfig:
         """
         pass

+    @abstractmethod
+    def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
+        """
+        Load the indicated model into memory and return a LoadedModel object.
+
+        :param key: Key of model config to be fetched.
+        :param submodel_type: For main (pipeline) models, the submodel to fetch.
+
+        Exceptions: UnknownModelException -- model with this key not known
+                    NotImplementedError -- a model loader was not provided at initialization time
+        """
+        pass
+
     @property
     @abstractmethod
     def metadata_store(self) -> ModelMetadataStore:
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index 4512da5d413..eee867ccb46 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -42,6 +42,7 @@

 import json
 import sqlite3
+import time
 from math import ceil
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
@@ -53,8 +54,10 @@
     ModelConfigFactory,
     ModelFormat,
     ModelType,
+    SubModelType,
 )
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
+from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel

 from ..shared.sqlite.sqlite_database import SqliteDatabase
 from .model_records_base import (
@@ -69,16 +72,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
     """Implementation of the ModelConfigStore ABC using a SQL database."""

-    def __init__(self, db: SqliteDatabase):
+    def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
         """
         Initialize a new object from a preexisting sqlite3 database and an optional model loader.

-        :param conn: sqlite3 connection object
-        :param lock: threading Lock object
+        :param db: Sqlite database object
+        :param loader: Initialized model loader object (optional)
         """
         super().__init__()
         self._db = db
-        self._cursor = self._db.conn.cursor()
+        self._cursor = db.conn.cursor()
+        self._loader = loader

     @property
     def db(self) -> SqliteDatabase:
@@ -199,7 +203,7 @@ def get_model(self, key: str) -> AnyModelConfig:
         with self._db.lock:
             self._cursor.execute(
                 """--sql
-                SELECT config FROM model_config
+                SELECT config, strftime('%s',updated_at) FROM model_config
                 WHERE id=?;
                 """,
                 (key,),
             )
             rows = self._cursor.fetchone()
             if not rows:
                 raise UnknownModelException("model not found")
-            model = ModelConfigFactory.make_config(json.loads(rows[0]))
+            model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
         return model

+    def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
+        """
+        Load the indicated model into memory and return a LoadedModel object.
+
+        :param key: Key of model config to be fetched.
+        :param submodel_type: For main (pipeline) models, the submodel to fetch.
+
+        Exceptions: UnknownModelException -- model with this key not known
+                    NotImplementedError -- a model loader was not provided at initialization time
+        """
+        if not self._loader:
+            raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
+        model_config = self.get_model(key)
+        return self._loader.load_model(model_config, submodel_type)
+
     def exists(self, key: str) -> bool:
         """
         Return True if a model with the indicated key exists in the database.
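The load_model path above is the first consumer of the loader stack introduced in this commit. A minimal usage sketch, assuming db is an initialized SqliteDatabase and model_loader an AnyModelLoader wired as in dependencies.py earlier in this patch; the model key is hypothetical:

    record_store = ModelRecordServiceSQL(db=db, loader=model_loader)
    config = record_store.get_model("my-vae-key")  # raises UnknownModelException if the key is absent
    loaded = record_store.load_model("my-vae-key", submodel_type=None)  # returns a LoadedModel
    # Constructing the service without a loader makes load_model raise NotImplementedError.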
@@ -265,12 +284,12 @@ def search_by_attr( with self._db.lock: self._cursor.execute( f"""--sql - select config FROM model_config + select config, strftime('%s',updated_at) FROM model_config {where}; """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -279,12 +298,12 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -293,12 +312,12 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE original_hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results @property diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 6079b3f08d7..681886eacd3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_3(app_config=config, logger=logger)) migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_5()) + migrator.register_migration(build_migration_6()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py new file mode 100644 index 00000000000..e72878f726f --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -0,0 +1,44 @@ +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +class Migration6Callback: + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._recreate_model_triggers(cursor) + + def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: + """ + Adds the timestamp trigger to the model_config table. 
+
+        This trigger was inadvertently dropped in earlier migration scripts.
+        """
+
+        cursor.execute(
+            """--sql
+            CREATE TRIGGER IF NOT EXISTS model_config_updated_at
+            AFTER UPDATE
+            ON model_config FOR EACH ROW
+            BEGIN
+                UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
+                    WHERE id = old.id;
+            END;
+            """
+        )
+
+
+def build_migration_6() -> Migration:
+    """
+    Build the migration from database version 5 to 6.
+
+    This migration does the following:
+    - Adds the model_config_updated_at trigger if it does not exist
+    """
+    migration_6 = Migration(
+        from_version=5,
+        to_version=6,
+        callback=Migration6Callback(),
+    )
+
+    return migration_6
diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py
index e54be527d95..8c03d2ccf84 100644
--- a/invokeai/backend/install/install_helper.py
+++ b/invokeai/backend/install/install_helper.py
@@ -98,11 +98,13 @@ def __init__(self) -> None:
         super().__init__()
         self._bars: Dict[str, tqdm] = {}
         self._last: Dict[str, int] = {}
+        self._logger = InvokeAILogger.get_logger(__name__)

     def dispatch(self, event_name: str, payload: Any) -> None:
-        """Dispatch an event by appending it to self.events."""
+        """Handle a model install event, updating the progress bars and logging the outcome."""
+        data = payload["data"]
+        source = data["source"]
         if payload["event"] == "model_install_downloading":
-            data = payload["data"]
             dest = data["local_path"]
             total_bytes = data["total_bytes"]
             bytes = data["bytes"]
@@ -111,7 +113,12 @@ def dispatch(self, event_name: str, payload: Any) -> None:
             self._last[dest] = 0
             self._bars[dest].update(bytes - self._last[dest])
             self._last[dest] = bytes
-
+        elif payload["event"] == "model_install_completed":
+            self._logger.info(f"{source}: installed successfully.")
+        elif payload["event"] == "model_install_error":
+            self._logger.warning(f"{source}: installation failed with error {data['error']}")
+        elif payload["event"] == "model_install_cancelled":
+            self._logger.warning(f"{source}: installation cancelled")

 class InstallHelper(object):
     """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py
index 0f16852c934..f3c84cd01f7 100644
--- a/invokeai/backend/model_manager/__init__.py
+++ b/invokeai/backend/model_manager/__init__.py
@@ -1,6 +1,7 @@
 """Re-export frequently-used symbols from the Model Manager backend."""

 from .config import (
+    AnyModel,
     AnyModelConfig,
     BaseModelType,
     InvalidModelConfigException,
@@ -14,12 +15,15 @@
 )
 from .probe import ModelProbe
 from .search import ModelSearch
+from .load import LoadedModel

 __all__ = [
+    "AnyModel",
     "AnyModelConfig",
     "BaseModelType",
     "ModelRepoVariant",
     "InvalidModelConfigException",
+    "LoadedModel",
     "ModelConfigFactory",
     "ModelFormat",
     "ModelProbe",
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 338669c873a..796ccbacde0 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -19,12 +19,15 @@ Validation errors will raise an InvalidModelConfigException error.
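+
+Configs read back from the database also carry a last_modified timestamp, populated
+from the model_config table's updated_at column (see ModelConfigBase.last_modified below).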
""" +import time +import torch from enum import Enum from typing import Literal, Optional, Type, Union from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from diffusers import ModelMixin from typing_extensions import Annotated, Any, Dict - +from .onnx_runtime import IAIOnnxRuntimeModel class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel): ) # if model is converted or otherwise modified, this will hold updated hash description: Optional[str] = Field(default=None) source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, @@ -280,6 +284,7 @@ class T2IConfig(ModelConfigBase): ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel] # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown @@ -312,6 +317,7 @@ def make_config( model_data: Union[dict, AnyModelConfig], key: Optional[str] = None, dest_class: Optional[Type] = None, + timestamp: Optional[float] = None ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. @@ -330,4 +336,6 @@ def make_config( model = AnyModelConfigValidator.validate_python(model_data) if key: model.key = key + if timestamp: + model.last_modified = timestamp return model diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py new file mode 100644 index 00000000000..9d6fc4841f2 --- /dev/null +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -0,0 +1,1744 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Adapted for use in InvokeAI by Lincoln Stein, July 2023 +# +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from pathlib import Path +from typing import Optional, Union + +import requests +import torch +from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available +from diffusers.utils.import_utils import BACKENDS_MAPPING +from picklescan.scanner import scan_file_path +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.model_manager import BaseModelType, ModelVariantType + +try: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig +except ImportError: + raise ImportError( + "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." + ) + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = InvokeAILogger.get_logger(__name__) +CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert" + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
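+
+    Example: shave_segments("a.b.c.d", 1) returns "b.c.d", while
+    shave_segments("a.b.c.d", -1) returns "a.b.c".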
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
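+    # Worked example with hypothetical shapes: an old qkv tensor whose first
+    # dimension is 3 * channels is reshaped to (num_heads, 3 * channels // num_heads, ...),
+    # split along dim=1 into equal query/key/value chunks, and each chunk is then
+    # reshaped to the flat (-1, channels) layout that diffusers expects.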
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
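+
+    When controlnet=True, the parameters are read from the checkpoint's
+    control_stage_config rather than its unet_config.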
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for _i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + 
Creates a config for the diffusers based on the config of the LDM model.
+    """
+    vae_params = original_config.model.params.first_stage_config.params.ddconfig
+    _ = original_config.model.params.first_stage_config.params.embed_dim
+
+    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+    config = {
+        "sample_size": image_size,
+        "in_channels": vae_params.in_channels,
+        "out_channels": vae_params.out_ch,
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "latent_channels": vae_params.z_channels,
+        "layers_per_block": vae_params.num_res_blocks,
+    }
+    return config
+
+
+def create_diffusers_schedular(original_config):
+    schedular = DDIMScheduler(
+        num_train_timesteps=original_config.model.params.timesteps,
+        beta_start=original_config.model.params.linear_start,
+        beta_end=original_config.model.params.linear_end,
+        beta_schedule="scaled_linear",
+    )
+    return schedular
+
+
+def create_ldm_bert_config(original_config):
+    bert_params = original_config.model.params.cond_stage_config.params
+    config = LDMBertConfig(
+        d_model=bert_params.n_embed,
+        encoder_layers=bert_params.n_layer,
+        encoder_ffn_dim=bert_params.n_embed * 4,
+    )
+    return config
+
+
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+
+    if skip_extract_state_dict:
+        unet_state_dict = checkpoint
+    else:
+        # extract state_dict for UNet
+        unet_state_dict = {}
+        keys = list(checkpoint.keys())
+
+        if controlnet:
+            unet_key = "control_model."
+        else:
+            unet_key = "model.diffusion_model."
+
+        # at least 100 parameters have to start with `model_ema` for the checkpoint to be considered EMA
+        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+            logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+            logger.warning(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            if sum(k.startswith("model_ema") for k in keys) > 100:
+                logger.warning(
+                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                )
+
+            for key in keys:
+                if key.startswith(unet_key):
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+    new_checkpoint = {}
+
+    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+    if config["class_embed_type"] is None:
+        # No parameters to port
+        ...
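+    # Note: the projection-style class embedding handled below and SDXL's
+    # "text_time" addition embedding both come from the same label_emb.0.* keys
+    # in the original checkpoint; only the destination parameter names differ.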
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": 
"mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config = 
CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load 
uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. 
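+
+    The choice is keyed off the embedder class named in the LDM config; a minimal
+    sketch of the dispatch performed below (the target value is illustrative):
+
+        target = original_config.model.params.embedder_config.target
+        embedder_class = target.split(".")[-1]
+        # -> "ClipImageEmbedder" or "FrozenOpenCLIPImageEmbedder"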
+ """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, + precision: Optional[torch.dtype] = None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + original_config = ctrlnet_config.copy() + + ctrlnet_config.pop("addition_embed_type") + ctrlnet_config.pop("addition_time_embed_dim") + ctrlnet_config.pop("transformer_layers_per_block") + + if 
use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + original_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet.to(precision) + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path: str, + model_version: BaseModelType, + model_variant: ModelVariantType, + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + load_safety_checker: bool = True, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + text_encoder=None, + tokenizer=None, + scan_needed: bool = True, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. 
+        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+            inference. Non-EMA weights are usually better to continue fine-tuning.
+        upcast_attention (`bool`, *optional*, defaults to `None`):
+            Whether the attention computation should always be upcasted. This is necessary when running stable
+            diffusion 2.1.
+        device (`str`, *optional*, defaults to `None`):
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`bool`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+        load_safety_checker (`bool`, *optional*, defaults to `True`):
+            Whether to load the safety checker or not. Defaults to `True`.
+        pipeline_class (`DiffusionPipeline`, *optional*, defaults to `None`):
+            The pipeline class to use. Pass `None` to determine automatically.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            Whether or not to only look at local files (i.e., do not try to download the model).
+        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
+        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+            An instance of
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
+            needed.
+        precision (`torch.dtype`, *optional*, defaults to `None`):
+            If not provided the precision will be set to the precision of the original file.
+    return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+    """
+
+    # import pipelines here to avoid circular import error when using from_single_file method
+    from diffusers import (
+        LDMTextToImagePipeline,
+        PaintByExamplePipeline,
+        StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
+        StableDiffusionPipeline,
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLPipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
+    )
+
+    if pipeline_class is None:
+        pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
+
+    if prediction_type == "v-prediction":
+        prediction_type = "v_prediction"
+
+    if from_safetensors:
+        from safetensors.torch import load_file as safe_load
+
+        checkpoint = safe_load(checkpoint_path, device="cpu")
+    else:
+        if scan_needed:
+            # scan model
+            scan_result = scan_file_path(checkpoint_path)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model {checkpoint_path} is potentially infected by malware. Aborting import.")
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    # Sometimes models don't have the global_step item
+    if "global_step" in checkpoint:
+        global_step = checkpoint["global_step"]
+    else:
+        logger.debug("global_step key not found in model")
+        global_step = None
+
+    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+    while "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+
+    logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
+
+    precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
+    logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
+    precision = precision or checkpoint[precision_probing_key].dtype
+
+    if original_config_file is None:
+        key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+        key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+
+        # model_type = "v1"
+        config_url = (
+            "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+        )
+
+        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
+            # model_type = "v2"
+            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+
+            if global_step == 110000:
+                # v2.1 needs to upcast attention
+                upcast_attention = True
+        elif key_name_sd_xl_base in checkpoint:
+            # only the base xl model has two text embedders
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+        elif key_name_sd_xl_refiner in checkpoint:
+            # the refiner xl model has an embedder but only one text embedder
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+
+        original_config_file = BytesIO(requests.get(config_url).content)
+
+    original_config = OmegaConf.load(original_config_file)
+    if original_config["model"]["params"].get("use_ema") is not None:
+        extract_ema = original_config["model"]["params"]["use_ema"]
+
+    if (
+        model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
+        and original_config["model"]["params"].get("parameterization") == "v"
+    ):
+        prediction_type = "v_prediction"
+        upcast_attention = True
+        image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512
+    else:
+        prediction_type = "epsilon"
+        upcast_attention = False
+        image_size = 512
+
+    # Convert the text model.
+ if ( + model_type is None + and "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ): + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: + num_in_channels = 9 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config.model.params: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config.model.params: + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + ) + + num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + 
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + controlnet=controlnet, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model.to(precision), + unet=unet.to(precision), + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == 
"karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") + + prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "clip-vit-large-patch14" + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + tokenizer = ( + CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + if tokenizer is None + else tokenizer + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + else: + safety_checker = None + feature_extractor = None + + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + if model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLPipeline( + vae=vae.to(precision), + text_encoder=text_encoder.to(precision), + 
tokenizer=tokenizer,
+                text_encoder_2=text_encoder_2.to(precision),
+                tokenizer_2=tokenizer_2,
+                unet=unet.to(precision),
+                scheduler=scheduler,
+                force_zeros_for_empty_prompt=True,
+            )
+        else:
+            tokenizer = None
+            text_encoder = None
+            tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!")
+
+            config_name = tokenizer_name
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+            )
+
+            pipe = StableDiffusionXLImg2ImgPipeline(
+                vae=vae.to(precision),
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                text_encoder_2=text_encoder_2,
+                tokenizer_2=tokenizer_2,
+                unet=unet.to(precision),
+                scheduler=scheduler,
+                requires_aesthetics_score=True,
+                force_zeros_for_empty_prompt=False,
+            )
+    else:
+        text_config = create_ldm_bert_config(original_config)
+        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
+        tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased")
+        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    return pipe
+
+
+def download_controlnet_from_original_ckpt(
+    checkpoint_path: str,
+    original_config_file: str,
+    image_size: int = 512,
+    extract_ema: bool = False,
+    precision: Optional[torch.dtype] = None,
+    num_in_channels: Optional[int] = None,
+    upcast_attention: Optional[bool] = None,
+    device: Optional[str] = None,
+    from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[int] = None,
+    scan_needed: bool = False,
+) -> ControlNetModel:
+
+    from omegaconf import OmegaConf
+
+    if from_safetensors:
+        from safetensors import safe_open
+
+        checkpoint = {}
+        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                checkpoint[key] = f.get_tensor(key)
+    else:
+        if scan_needed:
+            # scan model
+            scan_result = scan_file_path(checkpoint_path)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model {checkpoint_path} is potentially infected by malware.
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + # use original precision + precision_probing_key = "input_blocks.0.0.bias" + ckpt_precision = checkpoint[precision_probing_key].dtype + logger.debug(f"original controlnet precision = {ckpt_precision}") + precision = precision or ckpt_precision + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config.model.params: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet.to(precision) + + +def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: + vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +def convert_ckpt_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + use_safetensors: bool = True, + **kwargs, +): + """ + Takes all the arguments of download_from_original_stable_diffusion_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained( + dump_path, + safe_serialization=use_safetensors, + ) + + +def convert_controlnet_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + **kwargs, +): + """ + Takes all the arguments of download_controlnet_from_original_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index e69de29bb2d..357677bb7f7 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team +""" +Init file for the model loader. 
+""" +from importlib import import_module +from pathlib import Path +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger +from .load_base import AnyModelLoader, LoadedModel +from .model_cache.model_cache_default import ModelCache +from .convert_cache.convert_cache_default import ModelConvertCache + +# This registers the subclasses that implement loaders of specific model types +loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__'] +for module in loaders: + print(f'module={module}') + import_module(f"{__package__}.model_loaders.{module}") + +__all__ = ["AnyModelLoader", "LoadedModel"] + + +def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: + app_config = app_config or InvokeAIAppConfig.get_config() + logger = InvokeAILogger.get_logger(config=app_config) + return AnyModelLoader(app_config=app_config, + logger=logger, + ram_cache=ModelCache(logger=logger, + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size + ), + convert_cache=ModelConvertCache(app_config.models_convert_cache_path) + ) + diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py new file mode 100644 index 00000000000..eb3149be329 --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -0,0 +1,4 @@ +from .convert_cache_base import ModelConvertCacheBase +from .convert_cache_default import ModelConvertCache + +__all__ = ['ModelConvertCacheBase', 'ModelConvertCache'] diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py new file mode 100644 index 00000000000..25263f96aaa --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -0,0 +1,28 @@ +""" +Disk-based converted model cache. +""" +from abc import ABC, abstractmethod +from pathlib import Path + +class ModelConvertCacheBase(ABC): + + @property + @abstractmethod + def max_size(self) -> float: + """Return the maximum size of this cache directory.""" + pass + + @abstractmethod + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + pass + + @abstractmethod + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + pass + diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py new file mode 100644 index 00000000000..f799510ec5b --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -0,0 +1,64 @@ +""" +Placeholder for convert cache implementation. 
+""" + +from pathlib import Path +import shutil +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util import GIG, directory_size +from .convert_cache_base import ModelConvertCacheBase + +class ModelConvertCache(ModelConvertCacheBase): + + def __init__(self, cache_path: Path, max_size: float=10.0): + """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" + if not cache_path.exists(): + cache_path.mkdir(parents=True) + self._cache_path = cache_path + self._max_size = max_size + + @property + def max_size(self) -> float: + """Return the maximum size of this cache directory (GB).""" + return self._max_size + + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + return self._cache_path / key + + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + size_needed = directory_size(self._cache_path) + size + max_size = int(self.max_size) * GIG + logger = InvokeAILogger.get_logger() + + if size_needed <= max_size: + return + + logger.debug( + f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming." + ) + + # For this to work, we make the assumption that the directory contains + # a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level. + # This should be true for any diffusers model. + def by_atime(path: Path) -> float: + for config in ["model_index.json", "unet/config.json", "config.json"]: + sentinel = path / config + if sentinel.exists(): + return sentinel.stat().st_atime + return 0.0 + + # sort by last access time - least accessed files will be at the end + lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) + logger.debug(f"cached models in descending atime order: {lru_models}") + while size_needed > max_size and len(lru_models) > 0: + next_victim = lru_models.pop() + victim_size = directory_size(next_victim) + logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB") + shutil.rmtree(next_victim) + size_needed -= victim_size diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7cb7222b717..3ade83160a2 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -16,39 +16,11 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union -import torch -from diffusers import DiffusionPipeline -from injector import inject - from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel -from invokeai.backend.model_manager.ram_cache import ModelCacheBase - -AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel] - - -class ModelLockerBase(ABC): - """Base class for the model locker used by the loader.""" - - @abstractmethod - def lock(self) -> None: - """Lock the contained model and move it into VRAM.""" - pass - - @abstractmethod - def unlock(self) -> None: - """Unlock the contained model, and remove it from VRAM.""" - pass - - @property 
- @abstractmethod - def model(self) -> AnyModel: - """Return the model.""" - pass - +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase +from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase @dataclass class LoadedModel: @@ -69,7 +41,7 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None: @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self.locker.model() + return self.locker.model class ModelLoaderBase(ABC): @@ -89,9 +61,9 @@ def __init__( @abstractmethod def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ - Return a model given its key. + Return a model given its confguration. - Given a model key identified in the model configuration backend, + Given a model identified in the model configuration backend, return a ModelInfo object that can be used to retrieve the model. :param model_config: Model configuration, as returned by ModelConfigRecordStore @@ -115,34 +87,32 @@ class AnyModelLoader: # this tracks the loader subclasses _registry: Dict[str, Type[ModelLoaderBase]] = {} - @inject def __init__( self, - store: ModelRecordServiceBase, app_config: InvokeAIAppConfig, logger: Logger, ram_cache: ModelCacheBase, convert_cache: ModelConvertCacheBase, ): - """Store the provided ModelRecordServiceBase and empty the registry.""" - self._store = store + """Initialize AnyModelLoader with its dependencies.""" self._app_config = app_config self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """ - Return a model given its key. + @property + def ram_cache(self) -> ModelCacheBase: + """Return the RAM cache associated used by the loaders.""" + return self._ram_cache - Given a model key identified in the model configuration backend, - return a ModelInfo object that can be used to retrieve the model. + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel: + """ + Return a model given its configuration. :param key: model key, as known to the config backend :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. 
ModelType.Vae) """ - model_config = self._store.get_model(key) implementation = self.__class__.get_implementation( base=model_config.base, type=model_config.type, format=model_config.format ) @@ -165,7 +135,7 @@ def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelF implementation = cls._registry.get(key1) or cls._registry.get(key2) if not implementation: raise NotImplementedError( - "No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" + f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" ) return implementation @@ -176,18 +146,10 @@ def register( """Define a decorator which registers the subclass of loader.""" def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - print("Registering class", subclass.__name__) + print("DEBUG: Registering class", subclass.__name__) key = cls._to_registry_key(base, type, format) cls._registry[key] = subclass return subclass return decorator - -# in _init__.py will call something like -# def configure_loader_dependencies(binder): -# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton) -# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton) -# etc -# injector = Injector(configure_loader_dependencies) -# loader = injector.get(ModelFactory) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index eb2d432aaae..0b028235fdd 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -8,15 +8,14 @@ from diffusers import ModelMixin from diffusers.configuration_utils import ConfigMixin -from injector import inject from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType -from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init -from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -35,7 +34,6 @@ def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" - @inject # can inject instances of each of the classes in the call signature def __init__( self, app_config: InvokeAIAppConfig, @@ -87,18 +85,15 @@ def _get_model_path( def _convert_if_needed( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None ) -> Path: - if not self._needs_conversion(config): - return model_path - - self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) cache_path: Path = self._convert_cache.cache_path(config.key) - if cache_path.exists(): - return cache_path - self._convert_model(model_path, cache_path) - return cache_path + if not self._needs_conversion(config, model_path, cache_path): + 
return cache_path if cache_path.exists() else model_path
+
+        self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
+        return self._convert_model(config, model_path, cache_path)
 
-    def _needs_conversion(self, config: AnyModelConfig) -> bool:
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
         return False
 
     def _load_if_needed(
@@ -133,7 +128,7 @@ def get_size_fs(
             variant=config.repo_variant if hasattr(config, "repo_variant") else None,
         )
 
-    def _convert_model(self, model_path: Path, cache_path: Path) -> None:
+    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
         raise NotImplementedError
 
     def _load_model(
diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py
new file mode 100644
index 00000000000..776b9d8936d
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/__init__.py
@@ -0,0 +1,5 @@
+"""Init file for the model cache."""
+
+from .model_cache_base import ModelCacheBase
+from .model_cache_default import ModelCache
+__all__ = ['ModelCacheBase', 'ModelCache']
diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
similarity index 77%
rename from invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py
rename to invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index cd80d1e78b2..50b69d961c6 100644
--- a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -10,34 +10,41 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Optional
+from typing import Dict, Optional, TypeVar, Generic
 
 import torch
 
-from invokeai.backend.model_manager import SubModelType
-from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
+from invokeai.backend.model_manager import AnyModel, SubModelType
 
+class ModelLockerBase(ABC):
+    """Base class for the model locker used by the loader."""
 
-@dataclass
-class CacheStats(object):
-    """Data object to record statistics on cache hits/misses."""
+    @abstractmethod
+    def lock(self) -> AnyModel:
+        """Lock the contained model and move it into VRAM."""
+        pass
 
-    hits: int = 0  # cache hits
-    misses: int = 0  # cache misses
-    high_watermark: int = 0  # amount of cache used
-    in_cache: int = 0  # number of models in cache
-    cleared: int = 0  # number of models cleared to make space
-    cache_size: int = 0  # total size of cache
-    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
+    @abstractmethod
+    def unlock(self) -> None:
+        """Unlock the contained model, and remove it from VRAM."""
+        pass
+
+    @property
+    @abstractmethod
+    def model(self) -> AnyModel:
+        """Return the model."""
+        pass
 
+T = TypeVar("T")
 
 @dataclass
-class CacheRecord:
+class CacheRecord(Generic[T]):
     """Elements of the cache."""
 
     key: str
-    model: AnyModel
+    model: T
     size: int
+    loaded: bool = False
     _locks: int = 0
 
     def lock(self) -> None:
@@ -55,7 +62,7 @@ def locked(self) -> bool:
         return self._locks > 0
 
 
-class ModelCacheBase(ABC):
+class ModelCacheBase(ABC, Generic[T]):
     """Virtual base class for RAM model cache."""
 
     @property
@@ -76,8 +83,14 @@ def lazy_offloading(self) -> bool:
         """Return true if the cache is configured to lazily offload models in VRAM."""
         pass
 
+    @property
     @abstractmethod
-    def offload_unlocked_models(self) -> None:
+    def max_cache_size(self) -> float:
+        """Return the cap on the cache size (GB)."""
+        pass
+
+    @abstractmethod
+    def offload_unlocked_models(self, size_required: int) -> None:
         """Offload from VRAM any models not actively in use."""
         pass
 
@@ -101,7 +114,7 @@ def make_room(self, size: int) -> None:
     def put(
         self,
         key: str,
-        model: AnyModel,
+        model: T,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
@@ -134,11 +147,6 @@ def cache_size(self) -> int:
         """Get the total size of the models currently cached."""
         pass
 
-    @abstractmethod
-    def get_stats(self) -> CacheStats:
-        """Return cache hit/miss/size statistics."""
-        pass
-
     @abstractmethod
     def print_cuda_stats(self) -> None:
         """Log debugging information on CUDA usage."""
diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
similarity index 63%
rename from invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py
rename to invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index bd43e978c83..961f68a4bea 100644
--- a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -18,6 +18,7 @@
 """
 
+import gc
 import math
 import time
 from contextlib import suppress
@@ -26,14 +27,14 @@
 
 import torch
 
-from invokeai.app.services.model_records import UnknownModelException
 from invokeai.backend.model_manager import SubModelType
-from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
+from invokeai.backend.model_manager.load.load_base import AnyModel
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
-from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger
+from .model_cache_base import CacheRecord, ModelCacheBase
+from .model_locker import ModelLockerBase, ModelLocker
 
 if choose_torch_device() == torch.device("mps"):
     from torch import mps
@@ -52,7 +53,7 @@
 MB = 2**20
 
 
-class ModelCache(ModelCacheBase):
+class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""
 
     def __init__(
@@ -92,62 +93,9 @@ def __init__(
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
 
-        # used for stats collection
-        self.stats = None
-
-        self._cached_models: Dict[str, CacheRecord] = {}
+        self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []
 
-    class ModelLocker(ModelLockerBase):
-        """Internal class that mediates movement in and out of GPU."""
-
-        def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord):
-            """
-            Initialize the model locker.
- - :param cache: The ModelCache object - :param cache_entry: The entry in the model cache - """ - self._cache = cache - self._cache_entry = cache_entry - - @property - def model(self) -> AnyModel: - """Return the model without moving it around.""" - return self._cache_entry.model - - def lock(self) -> Any: - """Move the model into the execution device (GPU) and lock it.""" - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this code to move it into GPU! - self._cache_entry.lock() - - try: - if self._cache.lazy_offloading: - self._cache.offload_unlocked_models() - - self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) - - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") - self._cache.print_cuda_stats() - - except Exception: - self._cache_entry.unlock() - raise - return self.model - - def unlock(self) -> None: - """Call upon exit from context.""" - if not hasattr(self.model, "to"): - return - - self._cache_entry.unlock() - if not self._cache.lazy_offloading: - self._cache.offload_unlocked_models() - self._cache.print_cuda_stats() - @property def logger(self) -> Logger: """Return the logger used by the cache.""" @@ -168,6 +116,11 @@ def execution_device(self) -> torch.device: """Return the exection device (e.g. "cuda" for VRAM).""" return self._execution_device + @property + def max_cache_size(self) -> float: + """Return the cap on cache size.""" + return self._max_cache_size + def cache_size(self) -> int: """Get the total size of the models currently cached.""" total = 0 @@ -207,18 +160,18 @@ def get( """ Retrieve model using key and optional submodel_type. - This may return an UnknownModelException if the model is not in the cache. + This may return an IndexError if the model is not in the cache. 
""" key = self._make_cache_key(key, submodel_type) if key not in self._cached_models: - raise UnknownModelException + raise IndexError(f"The model with key {key} is not in the cache.") # this moves the entry to the top (right end) of the stack with suppress(Exception): self._cache_stack.remove(key) self._cache_stack.append(key) cache_entry = self._cached_models[key] - return self.ModelLocker( + return ModelLocker( cache=self, cache_entry=cache_entry, ) @@ -234,19 +187,19 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] else: return model_key - def offload_unlocked_models(self) -> None: + def offload_unlocked_models(self, size_required: int) -> None: """Move any unused models from VRAM.""" reserved = self._max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break if not cache_entry.locked: self.move_model_to_device(cache_entry, self.storage_device) - - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + cache_entry.loaded = False + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB") torch.cuda.empty_cache() if choose_torch_device() == torch.device("mps"): @@ -305,28 +258,111 @@ def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.de def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % self.cache_size() + ram = "%4.2fG" % (self.cache_size() / GIG) - cached_models = 0 - loaded_models = 0 - locked_models = 0 + in_ram_models = 0 + in_vram_models = 0 + locked_in_vram_models = 0 for cache_record in self._cached_models.values(): - cached_models += 1 assert hasattr(cache_record.model, "device") - if cache_record.model.device is self.storage_device: - loaded_models += 1 + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 if cache_record.locked: - locked_models += 1 + locked_in_vram_models += 1 self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" - f" {cached_models}/{loaded_models}/{locked_models}" + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" ) - def get_stats(self) -> CacheStats: - """Return cache hit/miss/size statistics.""" - raise NotImplementedError - - def make_room(self, size: int) -> None: + def make_room(self, model_size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" - raise NotImplementedError + # calculate how much memory this model will require + # multiplier = 2 if self.precision==torch.float32 else 1 + bytes_needed = model_size + maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes + current_size = self.cache_size() + + if current_size + bytes_needed > maximum_size: + self.logger.debug( + f"Max cache size exceeded: 
{(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
+                f" {(bytes_needed/GIG):.2f} GB"
+            )
+
+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
+
+        pos = 0
+        models_cleared = 0
+        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
+            model_key = self._cache_stack[pos]
+            cache_entry = self._cached_models[model_key]
+
+            refs = sys.getrefcount(cache_entry.model)
+
+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+
+            # manually clear local variable references of just-finished function calls
+            # for some reason Python doesn't want to collect them even with an immediate gc.collect()
+            if refs > 2:
+                while True:
+                    cleared = False
+                    for referrer in gc.get_referrers(cache_entry.model):
+                        if type(referrer).__name__ == "frame":
+                            # RuntimeError: cannot clear an executing frame
+                            with suppress(RuntimeError):
+                                referrer.clear()
+                                cleared = True
+                            # break
+
+                    # repeat if referrers changed (due to frame clear); else exit loop
+                    if cleared:
+                        gc.collect()
+                    else:
+                        break
+
+            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+            self.logger.debug(
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
+                f" refs: {refs}"
+            )
+
+            # Expected refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            # 1 from onnx runtime object
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
+                self.logger.debug(
+                    f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
+                )
+                current_size -= cache_entry.size
+                models_cleared += 1
+                del self._cache_stack[pos]
+                del self._cached_models[model_key]
+                del cache_entry
+
+            else:
+                pos += 1
+
+        if models_cleared > 0:
+            # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
+            # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
+            # is high even if no garbage gets collected.)
+            #
+            # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
+            # - If models had to be cleared, it's a signal that we are close to our memory limit.
+            # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
+            #   collected.
+            #
+            # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
+            # immediately when their reference count hits 0.
+            gc.collect()
+
+        torch.cuda.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
+
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
new file mode 100644
index 00000000000..506d0129491
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -0,0 +1,59 @@
+"""
+Base class and implementation of a class that moves models in and out of VRAM.
+""" + +from abc import ABC, abstractmethod +from invokeai.backend.model_manager import AnyModel +from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord + +class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]): + """ + Initialize the model locker. + + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> AnyModel: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + self._cache_entry.loaded = True + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.print_cuda_stats() + diff --git a/invokeai/backend/model_manager/load/model_loaders/__init__.py b/invokeai/backend/model_manager/load/model_loaders/__init__.py new file mode 100644 index 00000000000..962cba54811 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/__init__.py @@ -0,0 +1,3 @@ +""" +Init file for model_loaders. +""" diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py new file mode 100644 index 00000000000..6f21c3d0903 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, Lincoln D. 
diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py
new file mode 100644
index 00000000000..6f21c3d0903
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/vae.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Class for VAE model loading in InvokeAI."""
+
+from pathlib import Path
+from typing import Optional
+
+import torch
+import safetensors
+from omegaconf import OmegaConf, DictConfig
+from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
+@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
+class VaeDiffusersModel(ModelLoader):
+    """Class to load VAE models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise Exception("There are no submodels in VAEs")
+        vae_class = self._get_hf_load_class(model_path)
+        variant = model_variant.value if model_variant else None
+        result: AnyModel = vae_class.from_pretrained(
+            model_path, torch_dtype=self._torch_dtype, variant=variant
+        )  # type: ignore
+        return result
+
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
+        print(f'DEBUG: last_modified={config.last_modified}')
+        print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}')
+        print(f'DEBUG: model_path={model_path.stat().st_mtime}')
+        if config.format != ModelFormat.Checkpoint:
+            return False
+        elif dest_path.exists() \
+             and (dest_path / "config.json").stat().st_mtime >= config.last_modified \
+             and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime:
+            return False
+        else:
+            return True
+
+    def _convert_model(self,
+                       config: AnyModelConfig,
+                       weights_path: Path,
+                       output_path: Path
+                       ) -> Path:
+        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
+            raise Exception(f"Vae conversion not supported for model type: {config.base}")
+        else:
+            config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
+
+        if weights_path.suffix == ".safetensors":
+            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
+        else:
+            checkpoint = torch.load(weights_path, map_location="cpu")
+
+        dtype = torch_dtype()
+
+        # sometimes weights are hidden under "state_dict", and sometimes not
+        if "state_dict" in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+
+        ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
+        assert isinstance(ckpt_config, DictConfig)
+
+        print(f'DEBUG: CONVERTING')
+        vae_model = convert_ldm_vae_to_diffusers(
+            checkpoint=checkpoint,
+            vae_config=ckpt_config,
+            image_size=512,
+        )
+        vae_model.to(dtype)  # set precision appropriately
+        vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype)
+        return output_path
+
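The _needs_conversion() test above reduces to two freshness comparisons against the converted copy's config.json. The same check as a standalone sketch; the helper name and the bare float last_modified argument are illustrative, not part of the codebase:

from pathlib import Path


def conversion_is_stale(last_modified: float, checkpoint: Path, converted: Path) -> bool:
    """Return True when a cached diffusers conversion must be redone.

    The cached copy is fresh only if its config.json is at least as new as
    both the model's database record and the original checkpoint file.
    """
    marker = converted / "config.json"
    if not converted.exists() or not marker.exists():
        return True
    cached = marker.stat().st_mtime
    return cached < last_modified or cached < checkpoint.stat().st_mtime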
diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py
index 18407cbca2e..7c27e66472f 100644
--- a/invokeai/backend/model_manager/load/model_util.py
+++ b/invokeai/backend/model_manager/load/model_util.py
@@ -48,6 +48,9 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
 
 def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
     """Estimate the size of a model on disk in bytes."""
+    if model_path.is_file():
+        return model_path.stat().st_size
+
     if subfolder is not None:
         model_path = model_path / subfolder
 
diff --git a/invokeai/backend/model_manager/load/ram_cache/__init__.py b/invokeai/backend/model_manager/load/ram_cache/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/invokeai/backend/model_manager/load/vae.py b/invokeai/backend/model_manager/load/vae.py
deleted file mode 100644
index a6cbe241e1e..00000000000
--- a/invokeai/backend/model_manager/load/vae.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
-"""Class for VAE model loading in InvokeAI."""
-
-from pathlib import Path
-from typing import Dict, Optional
-
-import torch
-
-from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
-from invokeai.backend.model_manager.load.load_base import AnyModelLoader
-from invokeai.backend.model_manager.load.load_default import ModelLoader
-
-
-@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
-class VaeDiffusersModel(ModelLoader):
-    """Class to load VAE models."""
-
-    def _load_model(
-        self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> Dict[str, torch.Tensor]:
-        if submodel_type is not None:
-            raise Exception("There are no submodels in VAEs")
-        vae_class = self._get_hf_load_class(model_path)
-        variant = model_variant.value if model_variant else ""
-        result: Dict[str, torch.Tensor] = vae_class.from_pretrained(
-            model_path, torch_dtype=self._torch_dtype, variant=variant
-        )  # type: ignore
-        return result
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index 87ae1480f54..0164dffe303 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -12,6 +12,14 @@
     torch_dtype,
 )
 from .logging import InvokeAILogger
-from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name  # noqa: F401
+from .util import (  # TO DO: Clean this up; remove the unused symbols
+    GIG,
+    Chdir,
+    ask_user,  # noqa
+    directory_size,
+    download_with_resume,
+    instantiate_from_config,  # noqa
+    url_attachment_name,  # noqa
+)
 
-__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
+__all__ = ["GIG", "directory_size", "Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index d6d3ad727f7..ad3f4e139a7 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import nullcontext
-from typing import Union
+from typing import Union, Optional
 
 import torch
 from torch import autocast
@@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
             return "float32"
 
 
-def torch_dtype(device: torch.device) -> torch.dtype:
+def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
+    device = device or choose_torch_device()
     precision = choose_precision(device)
     if precision == "float16":
         return torch.float16
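With the Optional default above, torch_dtype() may now be called with no argument and will pick the device itself. A short usage sketch; the float32-on-CPU result follows from choose_precision() and is stated here as an assumption rather than something shown in this patch:

import torch

from invokeai.backend.util.devices import torch_dtype

dtype = torch_dtype()  # derives the device via choose_torch_device(), then its precision
cpu_dtype = torch_dtype(torch.device("cpu"))  # expected to be torch.float32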
diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py
index 13751e27702..6589aa72784 100644
--- a/invokeai/backend/util/util.py
+++ b/invokeai/backend/util/util.py
@@ -24,6 +24,20 @@
 
 from .devices import torch_dtype
 
+# actual size of a gig
+GIG = 1073741824
+
+def directory_size(directory: Path) -> int:
+    """
+    Return the aggregate size of all files in a directory (bytes).
+    """
+    total = 0
+    for root, dirs, files in os.walk(directory):
+        for f in files:
+            total += Path(root, f).stat().st_size
+        for d in dirs:
+            total += Path(root, d).stat().st_size
+    return total
 
 def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)

From 34d5cad4c9fbe9919ccfc99cbe2a18278a9447d3 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 4 Feb 2024 17:23:10 -0500
Subject: [PATCH 075/100] loaders for main, controlnet, ip-adapter, clipvision and t2i

---
 .../app/services/config/config_default.py     |   2 +-
 .../model_records/model_records_base.py       |  11 +-
 .../model_records/model_records_sql.py        |  17 +-
 .../sqlite_migrator/migrations/migration_6.py |   5 +-
 invokeai/backend/install/install_helper.py    |   1 +
 .../model_management/models/controlnet.py     |   1 -
 invokeai/backend/model_manager/__init__.py    |   2 +-
 invokeai/backend/model_manager/config.py      |  15 +-
 .../convert_ckpt_to_diffusers.py              |   4 +-
 .../backend/model_manager/load/__init__.py    |  24 +-
 .../load/convert_cache/__init__.py            |   2 +-
 .../load/convert_cache/convert_cache_base.py  |   3 +-
 .../convert_cache/convert_cache_default.py    |  10 +-
 .../backend/model_manager/load/load_base.py   |  52 +-
 .../model_manager/load/load_default.py        |  50 +-
 .../model_manager/load/memory_snapshot.py     |   2 +-
 .../load/model_cache/__init__.py              |   2 +-
 .../load/model_cache/model_cache_base.py      |   8 +-
 .../load/model_cache/model_cache_default.py   |  46 +-
 .../load/model_cache/model_locker.py          |   6 +-
 .../load/model_loaders/controlnet.py          |  60 ++
 .../load/model_loaders/generic_diffusers.py   |  34 +
 .../load/model_loaders/ip_adapter.py          |  39 ++
 .../model_manager/load/model_loaders/lora.py  |  76 +++
 .../load/model_loaders/stable_diffusion.py    |  93 +++
 .../model_manager/load/model_loaders/vae.py   |  66 +-
 .../backend/model_manager/load/model_util.py  |   5 +-
 invokeai/backend/model_manager/lora.py        | 620 ++++++++++++++++++
 invokeai/backend/model_manager/probe.py       |   6 +-
 invokeai/backend/util/__init__.py             |  16 +-
 invokeai/backend/util/devices.py              |   2 +-
 invokeai/backend/util/util.py                 |   2 +
 32 files changed, 1123 insertions(+), 159 deletions(-)
 create mode 100644 invokeai/backend/model_manager/load/model_loaders/controlnet.py
 create mode 100644 invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
 create mode 100644 invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
 create mode 100644 invokeai/backend/model_manager/load/model_loaders/lora.py
 create mode 100644 invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
 create mode 100644 invokeai/backend/model_manager/lora.py

diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index b161ea18d61..b39e916da34 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -173,7 +173,7 @@ class InvokeBatch(InvokeAISettings):
 import os
 from pathlib import Path
-from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional
 
 from omegaconf import DictConfig, OmegaConf
 from pydantic import Field
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 31cfecb4ec8..42e3c8f83a7 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -11,7 +11,14 @@
 from pydantic import BaseModel, Field
 
 from invokeai.app.services.shared.pagination import PaginatedResults
-from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager import (
+    AnyModelConfig,
+    BaseModelType,
+    LoadedModel,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
 
 
@@ -108,7 +115,7 @@ def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedM
         Load the indicated model into memory and return a LoadedModel object.
 
         :param key: Key of model config to be fetched.
-        :param submodel_type: For main (pipeline models), the submodel to fetch
+        :param submodel_type: For main (pipeline) models, the submodel to fetch
 
         Exceptions: UnknownModelException -- model with this key not known
                     NotImplementedException -- a model loader was not provided at initialization time
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index eee867ccb46..b50cd17a75d 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -42,7 +42,6 @@
 
 import json
 import sqlite3
-import time
 from math import ceil
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
@@ -56,8 +55,8 @@
     ModelType,
     SubModelType,
 )
-from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
 from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
+from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
 
 from ..shared.sqlite.sqlite_database import SqliteDatabase
 from .model_records_base import (
@@ -72,7 +71,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
     """Implementation of the ModelConfigStore ABC using a SQL database."""
 
-    def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None):
+    def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None):
         """
         Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@@ -289,7 +288,9 @@ def search_by_attr( """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -303,7 +304,9 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -317,7 +320,9 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results @property diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py index e72878f726f..b4734445110 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -1,11 +1,9 @@ import sqlite3 -from logging import Logger -from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -class Migration6Callback: +class Migration6Callback: def __call__(self, cursor: sqlite3.Cursor) -> None: self._recreate_model_triggers(cursor) @@ -28,6 +26,7 @@ def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: """ ) + def build_migration_6() -> Migration: """ Build the migration from database version 5 to 6. diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 8c03d2ccf84..9f219132d4d 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -120,6 +120,7 @@ def dispatch(self, event_name: str, payload: Any) -> None: elif payload["event"] == "model_install_cancelled": self._logger.warning(f"{source}: installation cancelled") + class InstallHelper(object): """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index da269eba4b7..3b534cb9d14 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -139,7 +139,6 @@ def _convert_controlnet_ckpt_and_cache( cache it to disk, and return Path to converted file. If already on disk then just returns Path. 
""" - print(f"DEBUG: controlnet config = {model_config}") app_config = InvokeAIAppConfig.get_config() weights = app_config.root_path / model_path output_path = Path(output_path) diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index f3c84cd01f7..98cc5054c73 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -13,9 +13,9 @@ SchedulerPredictionType, SubModelType, ) +from .load import LoadedModel from .probe import ModelProbe from .search import ModelSearch -from .load import LoadedModel __all__ = [ "AnyModel", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 796ccbacde0..e59a84d7291 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -20,14 +20,16 @@ """ import time -import torch from enum import Enum from typing import Literal, Optional, Type, Union -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +import torch from diffusers import ModelMixin +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict + from .onnx_runtime import IAIOnnxRuntimeModel +from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase): vae: Optional[str] = Field(default=None) variant: ModelVariantType = ModelVariantType.Normal + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False ztsnr_training: bool = False @@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False class ONNXSD1Config(_MainConfig): @@ -276,6 +278,7 @@ class T2IConfig(ModelConfigBase): _ONNXConfig, _VaeConfig, _ControlNetConfig, + # ModelConfigBase, LoRAConfig, TextualInversionConfig, IPAdapterConfig, @@ -284,7 +287,7 @@ class T2IConfig(ModelConfigBase): ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) -AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel] +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus] # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown @@ -317,7 +320,7 @@ def make_config( model_data: Union[dict, AnyModelConfig], key: Optional[str] = None, dest_class: Optional[Type] = None, - timestamp: Optional[float] = None + timestamp: Optional[float] = None, ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. 
diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py index 9d6fc4841f2..6f5acd58329 100644 --- a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -43,7 +43,6 @@ UnCLIPScheduler, ) from diffusers.utils import is_accelerate_available -from diffusers.utils.import_utils import BACKENDS_MAPPING from picklescan.scanner import scan_file_path from transformers import ( AutoFeatureExtractor, @@ -58,8 +57,8 @@ ) from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.model_manager import BaseModelType, ModelVariantType +from invokeai.backend.util.logging import InvokeAILogger try: from omegaconf import OmegaConf @@ -1643,7 +1642,6 @@ def download_controlnet_from_original_ckpt( cross_attention_dim: Optional[bool] = None, scan_needed: bool = False, ) -> DiffusionPipeline: - from omegaconf import OmegaConf if from_safetensors: diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 357677bb7f7..19b0116ba3b 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -8,14 +8,15 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger + +from .convert_cache.convert_cache_default import ModelConvertCache from .load_base import AnyModelLoader, LoadedModel from .model_cache.model_cache_default import ModelCache -from .convert_cache.convert_cache_default import ModelConvertCache # This registers the subclasses that implement loaders of specific model types -loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__'] +loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: - print(f'module={module}') + print(f"module={module}") import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel"] @@ -24,12 +25,11 @@ def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: app_config = app_config or InvokeAIAppConfig.get_config() logger = InvokeAILogger.get_logger(config=app_config) - return AnyModelLoader(app_config=app_config, - logger=logger, - ram_cache=ModelCache(logger=logger, - max_cache_size=app_config.ram_cache_size, - max_vram_cache_size=app_config.vram_cache_size - ), - convert_cache=ModelConvertCache(app_config.models_convert_cache_path) - ) - + return AnyModelLoader( + app_config=app_config, + logger=logger, + ram_cache=ModelCache( + logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size + ), + convert_cache=ModelConvertCache(app_config.models_convert_cache_path), + ) diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py index eb3149be329..5be56d2d584 100644 --- a/invokeai/backend/model_manager/load/convert_cache/__init__.py +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -1,4 +1,4 @@ from .convert_cache_base import ModelConvertCacheBase from .convert_cache_default import ModelConvertCache -__all__ = ['ModelConvertCacheBase', 'ModelConvertCache'] +__all__ = ["ModelConvertCacheBase", "ModelConvertCache"] diff --git 
a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py index 25263f96aaa..6268c099a5f 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -4,8 +4,8 @@ from abc import ABC, abstractmethod from pathlib import Path -class ModelConvertCacheBase(ABC): +class ModelConvertCacheBase(ABC): @property @abstractmethod def max_size(self) -> float: @@ -25,4 +25,3 @@ def make_room(self, size: float) -> None: def cache_path(self, key: str) -> Path: """Return the path for a model with the indicated key.""" pass - diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index f799510ec5b..4c361258d90 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -2,15 +2,17 @@ Placeholder for convert cache implementation. """ -from pathlib import Path import shutil -from invokeai.backend.util.logging import InvokeAILogger +from pathlib import Path + from invokeai.backend.util import GIG, directory_size +from invokeai.backend.util.logging import InvokeAILogger + from .convert_cache_base import ModelConvertCacheBase -class ModelConvertCache(ModelConvertCacheBase): - def __init__(self, cache_path: Path, max_size: float=10.0): +class ModelConvertCache(ModelConvertCacheBase): + def __init__(self, cache_path: Path, max_size: float = 10.0): """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" if not cache_path.exists(): cache_path.mkdir(parents=True) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 3ade83160a2..7d4e8337c3c 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -10,17 +10,19 @@ # do something with loaded_model """ +import hashlib from abc import ABC, abstractmethod from dataclasses import dataclass from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase -from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase +from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase + @dataclass class LoadedModel: @@ -52,7 +54,7 @@ def __init__( self, app_config: InvokeAIAppConfig, logger: Logger, - ram_cache: ModelCacheBase, + ram_cache: ModelCacheBase[AnyModel], convert_cache: ModelConvertCacheBase, ): """Initialize the loader.""" @@ -91,7 +93,7 @@ def __init__( self, app_config: InvokeAIAppConfig, logger: Logger, - ram_cache: ModelCacheBase, + ram_cache: ModelCacheBase[AnyModel], convert_cache: ModelConvertCacheBase, ): """Initialize AnyModelLoader with 
its dependencies."""
@@ -101,11 +103,11 @@ def __init__(
         self._convert_cache = convert_cache
 
     @property
-    def ram_cache(self) -> ModelCacheBase:
+    def ram_cache(self) -> ModelCacheBase[AnyModel]:
         """Return the RAM cache used by the loaders."""
         return self._ram_cache
 
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel:
+    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Return a model given its configuration.
 
@@ -113,9 +115,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
         :param submodel_type: a ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae)
         """
-        implementation = self.__class__.get_implementation(
-            base=model_config.base, type=model_config.type, format=model_config.format
-        )
+        implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
         return implementation(
             app_config=self._app_config,
             logger=self._logger,
@@ -128,16 +128,37 @@ def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat)
         return "-".join([base.value, type.value, format.value])
 
     @classmethod
-    def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]:
+    def get_implementation(
+        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
+    ) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
         """Get subclass of ModelLoaderBase registered to handle base and type."""
-        key1 = cls._to_registry_key(base, type, format)  # for a specific base type
-        key2 = cls._to_registry_key(BaseModelType.Any, type, format)  # with wildcard Any
+        # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
+        conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
+
+        key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format)  # for a specific base type
+        key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format)  # with wildcard Any
         implementation = cls._registry.get(key1) or cls._registry.get(key2)
         if not implementation:
             raise NotImplementedError(
-                f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
+                f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
             )
-        return implementation
+        return implementation, conf2, submodel_type
+
+    @classmethod
+    def _handle_subtype_overrides(
+        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
+    ) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
+        if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
+            model_path = Path(config.vae)
+            config_class = (
+                VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
+            )
+            hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
+            new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
+            submodel_type = None
+        else:
+            new_conf = config
+        return new_conf, submodel_type
 
     @classmethod
     def register(
@@ -152,4 +173,3 @@ def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
             return subclass
 
         return decorator
-
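The register()/get_implementation() pair above is a decorator-driven registry keyed on "base-type-format" strings, with a wildcard fallback on the base. A stripped-down sketch of the same pattern; Registry and its plain string keys are illustrative only:

from typing import Callable, Dict, Type


class Registry:
    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, base: str, type: str, format: str) -> Callable[[Type], Type]:
        def decorator(subclass: Type) -> Type:
            # Same key scheme as _to_registry_key(): "base-type-format".
            cls._registry["-".join([base, type, format])] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, base: str, type: str, format: str) -> Type:
        # Try the specific base first, then the "any" wildcard, mirroring
        # get_implementation() above.
        key1 = "-".join([base, type, format])
        key2 = "-".join(["any", type, format])
        return cls._registry.get(key1) or cls._registry[key2]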
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 0b028235fdd..453283e9b4a 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -10,12 +10,12 @@
 from diffusers.configuration_utils import ConfigMixin
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
-from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
-from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
+from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
+from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import choose_torch_device, torch_dtype
 
 
@@ -38,7 +38,7 @@ def __init__(
         self,
         app_config: InvokeAIAppConfig,
         logger: Logger,
-        ram_cache: ModelCacheBase,
+        ram_cache: ModelCacheBase[AnyModel],
         convert_cache: ModelConvertCacheBase,
     ):
         """Initialize the loader."""
@@ -47,7 +47,6 @@ def __init__(
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
         self._torch_dtype = torch_dtype(choose_torch_device())
-        self._size: Optional[int] = None  # model size
 
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -63,9 +62,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
         if model_config.type == "main" and not submodel_type:
             raise InvalidModelConfigException("submodel_type is required when loading a main model")
 
-        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
-        if is_submodel_override:
-            submodel_type = None
+        model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
 
         if not model_path.exists():
             raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
@@ -74,13 +71,12 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
         locker = self._load_if_needed(model_config, model_path, submodel_type)
         return LoadedModel(config=model_config, locker=locker)
 
-    # IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
-    # and submodels!!!!
     def _get_model_path(
         self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
-    ) -> Tuple[Path, bool]:
+    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
         model_base = self._app_config.models_path
-        return ((model_base / config.path).resolve(), False)
+        result = (model_base / config.path).resolve(), config, submodel_type
+        return result
 
     def _convert_if_needed(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
@@ -90,7 +86,7 @@ def _convert_if_needed(
         if not self._needs_conversion(config, model_path, cache_path):
             return cache_path if cache_path.exists() else model_path
 
-        self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
+        self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
         return self._convert_model(config, model_path, cache_path)
 
     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
@@ -114,6 +110,7 @@ def _load_if_needed(
                 config.key,
                 submodel_type=submodel_type,
                 model=loaded_model,
+                size=calc_model_size_by_data(loaded_model),
             )
 
         return self._ram_cache.get(config.key, submodel_type)
@@ -128,17 +125,6 @@ def get_size_fs(
             variant=config.repo_variant if hasattr(config, "repo_variant") else None,
         )
 
-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
-        raise NotImplementedError
-
-    def _load_model(
-        self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> AnyModel:
-        raise NotImplementedError
-
     def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
         return ConfigLoader.load_config(model_path, config_name=config_name)
 
@@ -161,3 +147,17 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT
         config = self._load_diffusers_config(model_path, config_name="config.json")
         class_name = config["_class_name"]
         return self._hf_definition_to_type(module="diffusers", class_name=class_name)
+
+    # This needs to be implemented in subclasses that handle checkpoints
+    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+        raise NotImplementedError
+
+    # This needs to be implemented in the subclass
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        raise NotImplementedError
+
diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py
index 504829a4271..295be0c5514 100644
--- a/invokeai/backend/model_manager/load/memory_snapshot.py
+++ b/invokeai/backend/model_manager/load/memory_snapshot.py
@@ -97,4 +97,4 @@ def get_msg_line(prefix: str, val1: int, val2: int) -> str:
     if snapshot_1.vram is not None and snapshot_2.vram is not None:
         msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
 
-    return msg
+    return "\n" + msg if len(msg) > 0 else msg
diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py
index 776b9d8936d..50cafa37696 100644
--- a/invokeai/backend/model_manager/load/model_cache/__init__.py
+++ b/invokeai/backend/model_manager/load/model_cache/__init__.py
@@ -2,4 +2,4 @@
 from .model_cache_base import ModelCacheBase
 from .model_cache_default import ModelCache
 
-_all__ = ['ModelCacheBase', 'ModelCache']
+__all__ = 
["ModelCacheBase", "ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 50b69d961c6..14a7dfb4a1f 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -8,14 +8,15 @@ """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from logging import Logger -from typing import Dict, Optional, TypeVar, Generic +from typing import Generic, Optional, TypeVar import torch from invokeai.backend.model_manager import AnyModel, SubModelType + class ModelLockerBase(ABC): """Base class for the model locker used by the loader.""" @@ -35,8 +36,10 @@ def model(self) -> AnyModel: """Return the model.""" pass + T = TypeVar("T") + @dataclass class CacheRecord(Generic[T]): """Elements of the cache.""" @@ -115,6 +118,7 @@ def put( self, key: str, model: T, + size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 961f68a4bea..688be8ceb48 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -19,22 +19,24 @@ """ import gc +import logging import math +import sys import time from contextlib import suppress from logging import Logger -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import torch from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager.load.load_base import AnyModel from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger + from .model_cache_base import CacheRecord, ModelCacheBase -from .model_locker import ModelLockerBase, ModelLocker +from .model_locker import ModelLocker, ModelLockerBase if choose_torch_device() == torch.device("mps"): from torch import mps @@ -91,7 +93,7 @@ def __init__( self._execution_device: torch.device = execution_device self._storage_device: torch.device = storage_device self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) - self._log_memory_usage = log_memory_usage + self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -141,14 +143,14 @@ def put( self, key: str, model: AnyModel, + size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" key = self._make_cache_key(key, submodel_type) assert key not in self._cached_models - loaded_model_size = calc_model_size_by_data(model) - cache_record = CacheRecord(key, model, loaded_model_size) + cache_record = CacheRecord(key, model, size) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -195,28 +197,32 @@ def offload_unlocked_models(self, size_required: int) -> None: for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: 
             break
+            if not cache_entry.loaded:
+                continue
             if not cache_entry.locked:
                 self.move_model_to_device(cache_entry, self.storage_device)
                 cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
+                self.logger.debug(
+                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
+                )
 
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
 
-    # TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size
-    # for printing debugging messages. Revisit whether this is necessary
-    def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
+    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device."""
-        # These attributes are not in the base class but in derived classes
-        assert hasattr(cache_entry.model, "device")
-        assert hasattr(cache_entry.model, "to")
+        # These attributes are not in the base ModelMixin class but in derived classes.
+        # Some models don't have these attributes, in which case they run in RAM/CPU.
+        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
+        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
+            return
 
         source_device = cache_entry.model.device
 
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
-        # multi-GPU.
+        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
+        # This would need to be revised to support multi-GPU.
         if torch.device(source_device).type == torch.device(target_device).type:
             return
 
@@ -227,8 +233,8 @@ def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.de
         end_model_to_time = time.time()
         self.logger.debug(
             f"Moved model '{cache_entry.key}' from {source_device} to"
-            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
-            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
+            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s. "
+            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" ) @@ -291,7 +297,7 @@ def make_room(self, model_size: int) -> None: f" {(bytes_needed/GIG):.2f} GB" ) - self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") + self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") pos = 0 models_cleared = 0 @@ -336,7 +342,7 @@ def make_room(self, model_size: int) -> None: # 1 from onnx runtime object if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): self.logger.debug( - f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" ) current_size -= cache_entry.size models_cleared += 1 @@ -365,4 +371,4 @@ def make_room(self, model_size: int) -> None: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") + self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 506d0129491..7a5fdd4284b 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -2,9 +2,10 @@ Base class and implementation of a class that moves models in and out of VRAM. """ -from abc import ABC, abstractmethod from invokeai.backend.model_manager import AnyModel -from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord + +from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase + class ModelLocker(ModelLockerBase): """Internal class that mediates movement in and out of GPU.""" @@ -56,4 +57,3 @@ def unlock(self) -> None: if not self._cache.lazy_offloading: self._cache.offload_unlocked_models(self._cache_entry.size) self._cache.print_cuda_stats() - diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py new file mode 100644 index 00000000000..8e6a80ceb20 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team
+"""Class for ControlNet model loading in InvokeAI."""
+
+from pathlib import Path
+
+import safetensors
+import torch
+
+from invokeai.backend.model_manager import (
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+)
+from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from .generic_diffusers import GenericDiffusersLoader
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
+class ControlnetLoader(GenericDiffusersLoader):
+    """Class to load ControlNet models."""
+
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
+        if config.format != ModelFormat.Checkpoint:
+            return False
+        elif (
+            dest_path.exists()
+            and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
+            and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
+        ):
+            return False
+        else:
+            return True
+
+    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
+            raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
+        else:
+            assert hasattr(config, "config")
+            config_file = config.config
+
+        if weights_path.suffix == ".safetensors":
+            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
+        else:
+            checkpoint = torch.load(weights_path, map_location="cpu")
+
+        # sometimes weights are hidden under "state_dict", and sometimes not
+        if "state_dict" in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+
+        convert_controlnet_to_diffusers(
+            weights_path,
+            output_path,
+            original_config_file=self._app_config.root_path / config_file,
+            image_size=512,
+            scan_needed=True,
+            from_safetensors=weights_path.suffix == ".safetensors",
+        )
+        return output_path
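Both the VAE and ControlNet converters repeat the same load-and-unwrap steps for checkpoint weights. The shared idiom, factored into a hypothetical helper (not a function that exists in the codebase):

from pathlib import Path
from typing import Any, Dict

import safetensors.torch
import torch


def load_state_dict(weights_path: Path) -> Dict[str, Any]:
    """Load a checkpoint and unwrap the optional 'state_dict' envelope."""
    if weights_path.suffix == ".safetensors":
        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
    else:
        checkpoint = torch.load(weights_path, map_location="cpu")
    # sometimes weights are hidden under "state_dict", and sometimes not
    return checkpoint.get("state_dict", checkpoint)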
diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
new file mode 100644
index 00000000000..f92a9048c50
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Class for simple diffusers model loading in InvokeAI."""
+
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.model_manager import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
+class GenericDiffusersLoader(ModelLoader):
+    """Class to load simple diffusers models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        model_class = self._get_hf_load_class(model_path)
+        if submodel_type is not None:
+            raise Exception(f"There are no submodels in models of type {model_class}")
+        variant = model_variant.value if model_variant else None
+        result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)  # type: ignore
+        return result
diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
new file mode 100644
index 00000000000..63dc3790f16
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Class for IP Adapter model loading in InvokeAI."""
+
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
+from invokeai.backend.model_manager import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
+class IPAdapterInvokeAILoader(ModelLoader):
+    """Class to load IP Adapter diffusers models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise ValueError("There are no submodels in an IP-Adapter model.")
+        model = build_ip_adapter(
+            ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
+            device=torch.device("cpu"),
+            dtype=self._torch_dtype,
+        )
+        return model
+
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
new file mode 100644
index 00000000000..4d19aadb7d2
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Class for LoRA model loading in InvokeAI."""
+
+
+from logging import Logger
+from pathlib import Path
+from typing import Optional, Tuple
+
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.model_manager import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.lora import LoRAModelRaw
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
+class LoraLoader(ModelLoader):
+    """Class to load LoRA models."""
+
+    # We cheat a little bit to get access to the model base
+    def __init__(
+        self,
+        app_config: InvokeAIAppConfig,
+        logger: Logger,
+        ram_cache: ModelCacheBase[AnyModel],
+        convert_cache: ModelConvertCacheBase,
+    ):
+        """Initialize the loader."""
+        super().__init__(app_config, logger, ram_cache, convert_cache)
+        self._model_base: Optional[BaseModelType] = None
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise ValueError("There are no submodels in a LoRA model.")
+        model = LoRAModelRaw.from_checkpoint(
+            file_path=model_path,
+            dtype=self._torch_dtype,
+            base_model=self._model_base,
+        )
+        return model
+
+    # override
+    def _get_model_path(
+        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
+    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
+        self._model_base = config.base  # cheating a little - setting this variable for later call to _load_model()
+
+        model_base_path = self._app_config.models_path
+        model_path = model_base_path / config.path
+
+        if config.format == ModelFormat.Diffusers:
+            for ext in ["safetensors", "bin"]:  # return path to the safetensors file inside the folder
+                path = model_base_path / config.path / f"pytorch_lora_weights.{ext}"
+                if path.exists():
+                    model_path = path
+                    break
+
+        result = model_path.resolve(), config, submodel_type
+        return result
+
+
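Taken together, these registrations make AnyModelLoader a single dispatch point for every model type. A usage sketch, assuming loader was built by get_standalone_loader() and config is an installed model's record; the submodel choice and the lock/unlock pairing are illustrative, not a prescribed API flow:

from invokeai.backend.model_manager import SubModelType

loaded = loader.load_model(config, submodel_type=SubModelType.Vae)
vae = loaded.locker.lock()  # moves the weights onto the execution device
try:
    ...  # run inference with the locked model here
finally:
    loaded.locker.unlock()  # lets the cache offload it again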
Stein and the InvokeAI Development Team
+"""Class for StableDiffusion model loading in InvokeAI."""
+
+
+from pathlib import Path
+from typing import Optional
+
+from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
+
+from invokeai.backend.model_manager import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    ModelVariantType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.config import MainCheckpointConfig
+from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
+class StableDiffusionDiffusersModel(ModelLoader):
+    """Class to load main models."""
+
+    model_base_to_model_type = {
+        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
+        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
+        BaseModelType.StableDiffusionXL: "SDXL",
+        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
+    }
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is None:
+            raise Exception("A submodel type must be provided when loading main pipelines.")
+        load_class = self._get_hf_load_class(model_path, submodel_type)
+        variant = model_variant.value if model_variant else None
+        model_path = model_path / submodel_type.value
+        result: AnyModel = load_class.from_pretrained(
+            model_path,
+            torch_dtype=self._torch_dtype,
+            variant=variant,
+        )  # type: ignore
+        return result
+
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
+        if config.format != ModelFormat.Checkpoint:
+            return False
+        elif (
+            dest_path.exists()
+            and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
+            and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
+        ):
+            return False
+        else:
+            return True
+
+    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+        assert isinstance(config, MainCheckpointConfig)
+        variant = config.variant
+        base = config.base
+        pipeline_class = (
+            StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
+        )
+
+        config_file = config.config
+
+        self._logger.info(f"Converting {weights_path} to diffusers format")
+        convert_ckpt_to_diffusers(
+            weights_path,
+            output_path,
+            model_type=self.model_base_to_model_type[base],
+            model_version=base,
+            model_variant=variant,
+            original_config_file=self._app_config.root_path / config_file,
+            extract_ema=True,
+            scan_needed=True,
+            pipeline_class=pipeline_class,
+            from_safetensors=weights_path.suffix == ".safetensors",
+            precision=self._torch_dtype,
+            load_safety_checker=False,
+        )
+        return output_path
diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py
index 6f21c3d0903..7a35e53459a 100644
--- a/invokeai/backend/model_manager/load/model_loaders/vae.py
+++ b/invokeai/backend/model_manager/load/model_loaders/vae.py
@@ -2,68 +2,54 @@
 """Class for VAE model loading in InvokeAI."""
 
 from pathlib import Path
-from
typing import Optional -import torch import safetensors -from omegaconf import OmegaConf, DictConfig -from invokeai.backend.util.devices import torch_dtype -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +import torch +from omegaconf import DictConfig, OmegaConf + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from .generic_diffusers import GenericDiffusersLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) @AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) -class VaeDiffusersModel(ModelLoader): +class VaeLoader(GenericDiffusersLoader): """Class to load VAE models.""" - def _load_model( - self, - model_path: Path, - model_variant: Optional[ModelRepoVariant] = None, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise Exception("There are no submodels in VAEs") - vae_class = self._get_hf_load_class(model_path) - variant = model_variant.value if model_variant else None - result: AnyModel = vae_class.from_pretrained( - model_path, torch_dtype=self._torch_dtype, variant=variant - ) # type: ignore - return result - def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: - print(f'DEBUG: last_modified={config.last_modified}') - print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}') - print(f'DEBUG: model_path={model_path.stat().st_mtime}') if config.format != ModelFormat.Checkpoint: return False - elif dest_path.exists() \ - and (dest_path / "config.json").stat().st_mtime >= config.last_modified \ - and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime: + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): return False else: return True - def _convert_model(self, - config: AnyModelConfig, - weights_path: Path, - output_path: Path - ) -> Path: + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + # TO DO: check whether sdxl VAE models convert. 
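+        # Conversion sketch: pick the legacy LDM yaml that matches the base model,
+        # load the raw checkpoint (safetensors or torch), unwrap an optional
+        # "state_dict" wrapper, then hand the tensors to convert_ldm_vae_to_diffusers().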
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") else: - config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + config_file = ( + "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + ) if weights_path.suffix == ".safetensors": checkpoint = safetensors.torch.load_file(weights_path, device="cpu") else: checkpoint = torch.load(weights_path, map_location="cpu") - dtype = torch_dtype() - # sometimes weights are hidden under "state_dict", and sometimes not if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] @@ -71,13 +57,11 @@ def _convert_model(self, ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) assert isinstance(ckpt_config, DictConfig) - print(f'DEBUG: CONVERTIGN') vae_model = convert_ldm_vae_to_diffusers( checkpoint=checkpoint, vae_config=ckpt_config, image_size=512, ) - vae_model.to(dtype) # set precision appropriately - vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype) + vae_model.to(self._torch_dtype) # set precision appropriately + vae_model.save_pretrained(output_path, safe_serialization=True) return output_path - diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 7c27e66472f..404c88bbbcd 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -8,10 +8,11 @@ import torch from diffusers import DiffusionPipeline +from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel -def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int: +def calc_model_size_by_data(model: AnyModel) -> int: """Get size of a model in memory in bytes.""" if isinstance(model, DiffusionPipeline): return _calc_pipeline_by_data(model) @@ -50,7 +51,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var """Estimate the size of a model on disk in bytes.""" if model_path.is_file(): return model_path.stat().st_size - + if subfolder is not None: model_path = model_path / subfolder diff --git a/invokeai/backend/model_manager/lora.py b/invokeai/backend/model_manager/lora.py new file mode 100644 index 00000000000..4c48de48ec7 --- /dev/null +++ b/invokeai/backend/model_manager/lora.py @@ -0,0 +1,620 @@ +# Copyright (c) 2024 The InvokeAI Development team +"""LoRA model support.""" + +import torch +from safetensors.torch import load_file +from pathlib import Path +from typing import Dict, Optional, Union, List, Tuple +from typing_extensions import Self +from invokeai.backend.model_manager import BaseModelType + +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): + # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + if "alpha" in values: + self.alpha = values["alpha"].item() + else: + self.alpha = None + + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: + self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor( + values["bias_indices"], + values["bias_values"], + tuple(values["bias_size"]), + ) + + 
else: + self.bias = None + + self.rank = None # set in layer implementation + self.layer_key = layer_key + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + def calc_size(self) -> int: + model_size = 0 + for val in [self.bias]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if self.bias is not None: + self.bias = self.bias.to(device=device, dtype=dtype) + + +# TODO: find and debug lora/locon with bias +class LoRALayer(LoRALayerBase): + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.up = values["lora_up.weight"] + self.down = values["lora_down.weight"] + if "lora_mid.weight" in values: + self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] + else: + self.mid = None + + self.rank = self.down.shape[0] + + def get_weight(self, orig_weight: torch.Tensor): + if self.mid is not None: + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) + weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) + else: + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.up, self.mid, self.down]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + self.up = self.up.to(device=device, dtype=dtype) + self.down = self.down.to(device=device, dtype=dtype) + + if self.mid is not None: + self.mid = self.mid.to(device=device, dtype=dtype) + + +class LoHALayer(LoRALayerBase): + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor] + ): + super().__init__(layer_key, values) + + self.w1_a = values["hada_w1_a"] + self.w1_b = values["hada_w1_b"] + self.w2_a = values["hada_w2_a"] + self.w2_b = values["hada_w2_b"] + + if "hada_t1" in values: + self.t1: Optional[torch.Tensor] = values["hada_t1"] + else: + self.t1 = None + + if "hada_t2" in values: + self.t2: Optional[torch.Tensor] = values["hada_t2"] + else: + self.t2 = None + + self.rank = self.w1_b.shape[0] + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + if self.t1 is None: + weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + + else: + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) + weight = rebuild1 * rebuild2 + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + 
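+        # LoHA stores its delta as a Hadamard product of two low-rank factors;
+        # get_weight() above rebuilds it as (w1_a @ w1_b) * (w2_a @ w2_b), optionally
+        # routed through the t1/t2 Tucker cores, so every factor moves together here.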
self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + if self.t1 is not None: + self.t1 = self.t1.to(device=device, dtype=dtype) + + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + + +class LoKRLayer(LoRALayerBase): + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + if "lokr_w1" in values: + self.w1: Optional[torch.Tensor] = values["lokr_w1"] + self.w1_a = None + self.w1_b = None + else: + self.w1 = None + self.w1_a = values["lokr_w1_a"] + self.w1_b = values["lokr_w1_b"] + + if "lokr_w2" in values: + self.w2: Optional[torch.Tensor] = values["lokr_w2"] + self.w2_a = None + self.w2_b = None + else: + self.w2 = None + self.w2_a = values["lokr_w2_a"] + self.w2_b = values["lokr_w2_b"] + + if "lokr_t2" in values: + self.t2: Optional[torch.Tensor] = values["lokr_t2"] + else: + self.t2 = None + + if "lokr_w1_b" in values: + self.rank = values["lokr_w1_b"].shape[0] + elif "lokr_w2_b" in values: + self.rank = values["lokr_w2_b"].shape[0] + else: + self.rank = None # unscaled + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + w1: Optional[torch.Tensor] = self.w1 + if w1 is None: + assert self.w1_a is not None + assert self.w1_b is not None + w1 = self.w1_a @ self.w1_b + + w2 = self.w2 + if w2 is None: + if self.t2 is None: + assert self.w2_a is not None + assert self.w2_b is not None + w2 = self.w2_a @ self.w2_b + else: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + assert w1 is not None + assert w2 is not None + weight = torch.kron(w1, w2) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + if self.w1 is not None: + self.w1 = self.w1.to(device=device, dtype=dtype) + else: + assert self.w1_a is not None + assert self.w1_b is not None + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + + if self.w2 is not None: + self.w2 = self.w2.to(device=device, dtype=dtype) + else: + assert self.w2_a is not None + assert self.w2_b is not None + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + + +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = 
None # unscaled + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + return self.weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + +class IA3Layer(LoRALayerBase): + # weight: torch.Tensor + # on_input: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.weight = values["weight"] + self.on_input = values["on_input"] + + self.rank = None # unscaled + + def get_weight(self, orig_weight: torch.Tensor): + weight = self.weight + if not self.on_input: + weight = weight.reshape(-1, 1) + return orig_weight * weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + model_size += self.on_input.nelement() * self.on_input.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + self.on_input = self.on_input.to(device=device, dtype=dtype) + +AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] + +# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix +class LoRAModelRaw: # (torch.nn.Module): + _name: str + layers: Dict[str, AnyLoRALayer] + + def __init__( + self, + name: str, + layers: Dict[str, AnyLoRALayer], + ): + self._name = name + self.layers = layers + + @property + def name(self) -> str: + return self._name + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + # TODO: try revert if exception? + for _key, layer in self.layers.items(): + layer.to(device=device, dtype=dtype) + + def calc_size(self) -> int: + model_size = 0 + for _, layer in self.layers.items(): + model_size += layer.calc_size() + return model_size + + @classmethod + def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Convert the keys of an SDXL LoRA state_dict to diffusers format. + + The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in + diffusers format, then this function will have no effect. + + This function is adapted from: + https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 + + Args: + state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. + + Raises: + ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. + + Returns: + Dict[str, Tensor]: The diffusers-format state_dict. + """ + converted_count = 0 # The number of Stability AI keys converted to diffusers format. + not_converted_count = 0 # The number of keys that were not converted. + + # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. + # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for + # `input_blocks_4_1_proj_in`. 
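+        # Illustrative sketch of the prefix lookup performed below:
+        #
+        #     keys = ["input_blocks_4_0", "input_blocks_4_1"]
+        #     pos = bisect.bisect_right(keys, "input_blocks_4_1_proj_in")  # -> 2
+        #     keys[pos - 1]  # -> "input_blocks_4_1"
+        #
+        # bisect_right() returns the insertion point, so the entry just before it is
+        # the closest candidate; the startswith() check below then confirms that the
+        # candidate really is a prefix of the search key.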
+ stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) + stability_unet_keys.sort() + + new_state_dict = {} + for full_key, value in state_dict.items(): + if full_key.startswith("lora_unet_"): + search_key = full_key.replace("lora_unet_", "") + # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. + position = bisect.bisect_right(stability_unet_keys, search_key) + map_key = stability_unet_keys[position - 1] + # Now, check if the map_key *actually* matches the search_key. + if search_key.startswith(map_key): + new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) + new_state_dict[new_key] = value + converted_count += 1 + else: + new_state_dict[full_key] = value + not_converted_count += 1 + elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. + new_state_dict[full_key] = value + continue + else: + raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") + + if converted_count > 0 and not_converted_count > 0: + raise ValueError( + f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," + f" not_converted={not_converted_count}" + ) + + return new_state_dict + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + base_model: Optional[BaseModelType] = None, + ) -> Self: + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + if isinstance(file_path, str): + file_path = Path(file_path) + + model = cls( + name=file_path.stem, # TODO: + layers={}, + ) + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + state_dict = cls._group_state(state_dict) + + if base_model == BaseModelType.StableDiffusionXL: + state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) + + for layer_key, values in state_dict.items(): + # lora and locon + if "lora_down.weight" in values: + layer: AnyLoRALayer = LoRALayer(layer_key, values) + + # loha + elif "hada_w1_b" in values: + layer = LoHALayer(layer_key, values) + + # lokr + elif "lokr_w1_b" in values or "lokr_w1" in values: + layer = LoKRLayer(layer_key, values) + + # diff + elif "diff" in values: + layer = FullLayer(layer_key, values) + + # ia3 + elif "weight" in values and "on_input" in values: + layer = IA3Layer(layer_key, values) + + else: + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") + + # lower memory consumption by removing already parsed layer values + state_dict[layer_key].clear() + + layer.to(device=device, dtype=dtype) + model.layers[layer_key] = layer + + return model + + @staticmethod + def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: + state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {} + + for key, value in state_dict.items(): + stem, leaf = key.split(".", 1) + if stem not in state_dict_groupped: + state_dict_groupped[stem] = {} + state_dict_groupped[stem][leaf] = value + + return state_dict_groupped + + +# code from +# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 +def make_sdxl_unet_conversion_map() -> 
List[Tuple[str,str]]: + """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." 
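+        # label_emb.0.{0,2} maps onto diffusers' add_embedding.linear_{1,2}, the MLP
+        # that embeds SDXL's additional conditioning (mirroring time_embed above).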
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { + sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() +} diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 9fd118b7822..64a20a20923 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -29,8 +29,12 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", + ModelVariantType.Normal: { + SchedulerPredictionType.Epsilon: "v1-inference.yaml", + SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", + }, ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", + ModelVariantType.Depth: "v2-midas-inference.yaml", }, BaseModelType.StableDiffusion2: { ModelVariantType.Normal: { diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 0164dffe303..7b48f0364ea 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -12,14 +12,22 @@ torch_dtype, ) from .logging import InvokeAILogger -from .util import ( # TO DO: Clean this up; remove the unused symbols +from .util import ( # TO DO: Clean this up; remove the unused symbols GIG, Chdir, ask_user, # noqa directory_size, download_with_resume, - instantiate_from_config, # noqa + instantiate_from_config, # noqa url_attachment_name, # noqa - ) +) -__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"] +__all__ = [ + "GIG", + "directory_size", + "Chdir", + "download_with_resume", + "InvokeAILogger", + "choose_precision", + "choose_torch_device", +] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index ad3f4e139a7..a787f9b6f42 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Union, Optional +from typing import Optional, Union import torch from torch import autocast diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 6589aa72784..ae376b41b25 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -27,6 +27,7 @@ # actual size of a gig GIG = 1073741824 + def directory_size(directory: Path) -> int: """ Return the aggregate size of all files in a directory (bytes). 
@@ -39,6 +40,7 @@ def directory_size(directory: Path) -> int: sum += Path(root, d).stat().st_size return sum + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot From ad2926a24c4b72ef31cd3afc847b403b5f65046e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 4 Feb 2024 23:18:00 -0500 Subject: [PATCH 076/100] added textual inversion and lora loaders --- .../model_install/model_install_default.py | 5 + .../{model_manager => embeddings}/lora.py | 35 +- invokeai/backend/embeddings/model_patcher.py | 586 ++++++++++++++++++ invokeai/backend/model_management/lora.py | 5 +- invokeai/backend/model_manager/config.py | 4 +- .../model_manager/load/load_default.py | 11 +- .../model_manager/load/memory_snapshot.py | 2 +- .../load/model_cache/__init__.py | 2 - .../load/model_loaders/controlnet.py | 4 +- .../load/model_loaders/generic_diffusers.py | 1 + .../load/model_loaders/ip_adapter.py | 6 +- .../model_manager/load/model_loaders/lora.py | 18 +- .../load/model_loaders/textual_inversion.py | 55 ++ .../model_manager/load/model_loaders/vae.py | 1 + .../backend/model_manager/load/model_util.py | 4 +- .../{model_manager => onnx}/onnx_runtime.py | 0 16 files changed, 701 insertions(+), 38 deletions(-) rename invokeai/backend/{model_manager => embeddings}/lora.py (96%) create mode 100644 invokeai/backend/embeddings/model_patcher.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/textual_inversion.py rename invokeai/backend/{model_manager => onnx}/onnx_runtime.py (100%) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 2b2294bfce4..1c188b300df 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -178,6 +178,11 @@ def install_path( ) def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 + similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] + if similar_jobs: + self._logger.warning(f"There is already an active install job for {source}. 
Not enqueuing.") + return similar_jobs[0] + if isinstance(source, LocalModelSource): install_job = self._import_local_model(source, config) self._install_queue.put(install_job) # synchronously install diff --git a/invokeai/backend/model_manager/lora.py b/invokeai/backend/embeddings/lora.py similarity index 96% rename from invokeai/backend/model_manager/lora.py rename to invokeai/backend/embeddings/lora.py index 4c48de48ec7..9a59a977087 100644 --- a/invokeai/backend/model_manager/lora.py +++ b/invokeai/backend/embeddings/lora.py @@ -1,13 +1,17 @@ # Copyright (c) 2024 The InvokeAI Development team """LoRA model support.""" +import bisect +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + import torch from safetensors.torch import load_file -from pathlib import Path -from typing import Dict, Optional, Union, List, Tuple from typing_extensions import Self + from invokeai.backend.model_manager import BaseModelType + class LoRALayerBase: # rank: Optional[int] # alpha: Optional[float] @@ -41,7 +45,7 @@ def __init__( self.rank = None # set in layer implementation self.layer_key = layer_key - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: raise NotImplementedError() def calc_size(self) -> int: @@ -82,7 +86,7 @@ def __init__( self.rank = self.down.shape[0] - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) @@ -121,11 +125,7 @@ class LoHALayer(LoRALayerBase): # t1: Optional[torch.Tensor] = None # t2: Optional[torch.Tensor] = None - def __init__( - self, - layer_key: str, - values: Dict[str, torch.Tensor] - ): + def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]): super().__init__(layer_key, values) self.w1_a = values["hada_w1_a"] @@ -145,7 +145,7 @@ def __init__( self.rank = self.w1_b.shape[0] - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.t1 is None: weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -227,7 +227,7 @@ def __init__( else: self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: w1: Optional[torch.Tensor] = self.w1 if w1 is None: assert self.w1_a is not None @@ -305,7 +305,7 @@ def __init__( self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: return self.weight def calc_size(self) -> int: @@ -330,7 +330,7 @@ class IA3Layer(LoRALayerBase): def __init__( self, layer_key: str, - values: Dict[str, torch.Tensor], + values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) @@ -339,10 +339,11 @@ def __init__( self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) + assert orig_weight is not None return orig_weight * weight def calc_size(self) -> int: @@ -361,8 +362,10 @@ def to( self.weight = self.weight.to(device=device, dtype=dtype) self.on_input = self.on_input.to(device=device, 
dtype=dtype) + AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] - + + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -530,7 +533,7 @@ def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, tor # code from # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 -def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]: +def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]: """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" unet_conversion_map_layer = [] diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py new file mode 100644 index 00000000000..6d73235197d --- /dev/null +++ b/invokeai/backend/embeddings/model_patcher.py @@ -0,0 +1,586 @@ +# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team +"""These classes implement model patching with LoRAs and Textual Inversions.""" +from __future__ import annotations + +import pickle +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import numpy as np +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel +from safetensors.torch import load_file +from transformers import CLIPTextModel, CLIPTokenizer +from typing_extensions import Self + +from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + +from .lora import LoRAModelRaw + +""" +loras = [ + (lora_model1, 0.7), + (lora_model2, 0.4), +] +with LoRAHelper.apply_lora_unet(unet, loras): + # unet with applied loras +# unmodified unet + +""" + + +# TODO: rename smth like ModelPatcher and add TI method? +class ModelPatcher: + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." 
+ submodule_name).lstrip(".") + + return (module_key, module) + + @classmethod + @contextmanager + def apply_lora_unet( + cls, + unet: UNet2DConditionModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> Generator[None, None, None]: + with cls.apply_lora(unet, loras, "lora_unet_"): + yield + + @classmethod + @contextmanager + def apply_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te1_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder2( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te2_"): + yield + + @classmethod + @contextmanager + def apply_lora( + cls, + model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel], + loras: List[Tuple[LoRAModelRaw, float]], + prefix: str, + ) -> Generator[None, None, None]: + original_weights = {} + try: + with torch.no_grad(): + for lora, lora_weight in loras: + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + if module_key not in original_weights: + original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) + layer.to(device=torch.device("cpu")) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + if module.weight.shape != layer_weight.shape: + # TODO: debug on lycoris + assert hasattr(layer_weight, "reshape") + layer_weight = layer_weight.reshape(module.weight.shape) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! 
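+                        # Net effect: the standard LoRA update W' = W + lora_weight * (alpha / rank) * delta_W,
+                        # where delta_W is the delta reconstructed by layer.get_weight(); the
+                        # original W was stashed in original_weights above and is restored on
+                        # context exit.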
+                        module.weight += layer_weight.to(dtype=dtype)
+
+            yield  # wait for context manager exit
+
+        finally:
+            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
+            with torch.no_grad():
+                for module_key, weight in original_weights.items():
+                    model.get_submodule(module_key).weight.copy_(weight)
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ti_list: List[Tuple[str, TextualInversionModel]],
+    ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
+        init_tokens_count = None
+        new_tokens_added = None
+
+        # TODO: This is required since Transformers 4.32 see
+        # https://github.com/huggingface/transformers/pull/25088
+        # More information by NVIDIA:
+        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+        # This value might need to be changed in the future and take the GPUs model into account as there seem
+        # to be ideal values for different GPUS. This value is temporary!
+        # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
+        pad_to_multiple_of = 8
+
+        try:
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
+            ti_manager = TextualInversionManager(ti_tokenizer)
+            init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
+
+            def _get_trigger(ti_name: str, index: int) -> str:
+                trigger = ti_name
+                if index > 0:
+                    trigger += f"-!pad-{index}"
+                return f"<{trigger}>"
+
+            def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
+                # for SDXL models, select the embedding that matches the text encoder's dimensions
+                if ti.embedding_2 is not None:
+                    return (
+                        ti.embedding_2
+                        if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
+                        else ti.embedding
+                    )
+                else:
+                    return ti.embedding
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti_name, ti in ti_list:
+                ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
+
+                for i in range(ti_embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
+
+            # Modify text_encoder.
+            # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
+            # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
+            # time.
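+            # (skip_torch_weight_init() is assumed here to temporarily disable torch's
+            # weight-init routines, so the rows added by resize_token_embeddings() are
+            # allocated without being initialized; each new row is overwritten with a
+            # TI embedding in the loop below anyway.)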
+ with skip_torch_weight_init(): + text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of) + model_embeddings = text_encoder.get_input_embeddings() + + for ti_name, ti in ti_list: + ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) + + ti_tokens = [] + for i in range(ti_embedding.shape[0]): + embedding = ti_embedding[i] + trigger = _get_trigger(ti_name, i) + + token_id = ti_tokenizer.convert_tokens_to_ids(trigger) + if token_id == ti_tokenizer.unk_token_id: + raise RuntimeError(f"Unable to find token id for token '{trigger}'") + + if model_embeddings.weight.data[token_id].shape != embedding.shape: + raise ValueError( + f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" + f" {embedding.shape[0]}, but the current model has token dimension" + f" {model_embeddings.weight.data[token_id].shape[0]}." + ) + + model_embeddings.weight.data[token_id] = embedding.to( + device=text_encoder.device, dtype=text_encoder.dtype + ) + ti_tokens.append(token_id) + + if len(ti_tokens) > 1: + ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] + + yield ti_tokenizer, ti_manager + + finally: + if init_tokens_count and new_tokens_added: + text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of) + + @classmethod + @contextmanager + def apply_clip_skip( + cls, + text_encoder: CLIPTextModel, + clip_skip: int, + ) -> Generator[None, None, None]: + skipped_layers = [] + try: + for _i in range(clip_skip): + skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1)) + + yield + + finally: + while len(skipped_layers) > 0: + text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) + + @classmethod + @contextmanager + def apply_freeu( + cls, + unet: UNet2DConditionModel, + freeu_config: Optional[FreeUConfig] = None, + ) -> Generator[None, None, None]: + did_apply_freeu = False + try: + assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? + if freeu_config is not None: + unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) + did_apply_freeu = True + + yield + + finally: + assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? + if did_apply_freeu: + unet.disable_freeu() + + +class TextualInversionModel: + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first',
+                    " token will be used.",
+                )
+
+            result.embedding = next(iter(state_dict["string_to_param"].values()))
+
+        # v3 (easynegative)
+        elif "emb_params" in state_dict:
+            result.embedding = state_dict["emb_params"]
+
+        # v5(sdxl safetensors file)
+        elif "clip_g" in state_dict and "clip_l" in state_dict:
+            result.embedding = state_dict["clip_g"]
+            result.embedding_2 = state_dict["clip_l"]
+
+        # v4(diffusers bin files)
+        else:
+            result.embedding = next(iter(state_dict.values()))
+
+            if len(result.embedding.shape) == 1:
+                result.embedding = result.embedding.unsqueeze(0)
+
+            if not isinstance(result.embedding, torch.Tensor):
+                raise ValueError(f"Invalid embeddings file: {file_path.name}")
+
+        return result
+
+
+# no type hints for BaseTextualInversionManager?
+class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
+    pad_tokens: Dict[int, List[int]]
+    tokenizer: CLIPTokenizer
+
+    def __init__(self, tokenizer: CLIPTokenizer):
+        self.pad_tokens = {}
+        self.tokenizer = tokenizer
+
+    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        if len(self.pad_tokens) == 0:
+            return token_ids
+
+        if token_ids[0] == self.tokenizer.bos_token_id:
+            raise ValueError("token_ids must not start with bos_token_id")
+        if token_ids[-1] == self.tokenizer.eos_token_id:
+            raise ValueError("token_ids must not end with eos_token_id")
+
+        new_token_ids = []
+        for token_id in token_ids:
+            new_token_ids.append(token_id)
+            if token_id in self.pad_tokens:
+                new_token_ids.extend(self.pad_tokens[token_id])
+
+        # Do not exceed the max model input size.
+        # The -2 here compensates for compel.embeddings_provider.get_token_ids(),
+        # which first removes and then adds back the start and end tokens.
+        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
+        if len(new_token_ids) > max_length:
+            new_token_ids = new_token_ids[0:max_length]
+
+        return new_token_ids
+
+
+class ONNXModelPatcher:
+    @classmethod
+    @contextmanager
+    def apply_lora_unet(
+        cls,
+        unet: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(unet, loras, "lora_unet_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora_text_encoder(
+        cls,
+        text_encoder: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(text_encoder, loras, "lora_te_"):
+            yield
+
+    # based on
+    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
+    @classmethod
+    @contextmanager
+    def apply_lora(
+        cls,
+        model: IAIOnnxRuntimeModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+        prefix: str,
+    ) -> Generator[None, None, None]:
+        if not isinstance(model, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        orig_weights = {}
+
+        try:
+            blended_loras: Dict[str, torch.Tensor] = {}
+
+            for lora, lora_weight in loras:
+                for layer_key, layer in lora.layers.items():
+                    if not layer_key.startswith(prefix):
+                        continue
+
+                    layer.to(dtype=torch.float32)
+                    layer_key = layer_key.replace(prefix, "")
+                    # TODO: rewrite to pass original tensor weight(required by ia3)
+                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
+                    if layer_key in blended_loras:
+                        blended_loras[layer_key] += layer_weight
+                    else:
+                        blended_loras[layer_key] = layer_weight
+
+            node_names = {}
+            for node in model.nodes.values():
+                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
+
+            for layer_key, lora_weight in blended_loras.items():
+                conv_key = layer_key + "_Conv"
+                gemm_key = layer_key + "_Gemm"
+                matmul_key = layer_key + "_MatMul"
+
+                if conv_key in node_names or gemm_key in node_names:
+                    if conv_key in node_names:
+                        conv_node = model.nodes[node_names[conv_key]]
+                    else:
+                        conv_node = model.nodes[node_names[gemm_key]]
+
+                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
+                    orig_weight = model.tensors[weight_name]
+
+                    if orig_weight.shape[-2:] == (1, 1):
+                        if lora_weight.shape[-2:] == (1, 1):
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
+                        else:
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight
+
+                        new_weight = np.expand_dims(new_weight, (2, 3))
+                    else:
+                        if orig_weight.shape != lora_weight.shape:
+                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
+                        else:
+                            new_weight = orig_weight + lora_weight
+
+                    orig_weights[weight_name] = orig_weight
+                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
+
+                elif matmul_key in node_names:
+                    weight_node = model.nodes[node_names[matmul_key]]
+                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
+
+                    orig_weight = model.tensors[matmul_name]
+                    new_weight = orig_weight + lora_weight.transpose()
+
+                    orig_weights[matmul_name] = orig_weight
+                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
+
+                else:
+                    # warn? err?
+                    pass
+
+            yield
+
+        finally:
+            # restore original weights
+            for name, orig_weight in orig_weights.items():
+                model.tensors[name] = orig_weight
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: IAIOnnxRuntimeModel,
+        ti_list: List[Tuple[str, Any]],
+    ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
+        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        orig_embeddings = None
+
+        try:
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
+            ti_manager = TextualInversionManager(ti_tokenizer)
+
+            def _get_trigger(ti_name: str, index: int) -> str:
+                trigger = ti_name
+                if index > 0:
+                    trigger += f"-!pad-{index}"
+                return f"<{trigger}>"
+
+            # modify text_encoder
+            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti_name, ti in ti_list:
+                if ti.embedding_2 is not None:
+                    ti_embedding = (
+                        ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
+                    )
+                else:
+                    ti_embedding = ti.embedding
+
+                for i in range(ti_embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
+
+            embeddings = np.concatenate(
+                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
+                axis=0,
+            )
+
+            for ti_name, ti in ti_list:
+                # Recompute this TI's embedding; reusing the previous loop's
+                # ti_embedding here would apply the final TI's weights to every entry.
+                if ti.embedding_2 is not None:
+                    ti_embedding = (
+                        ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
+                    )
+                else:
+                    ti_embedding = ti.embedding
+
+                ti_tokens = []
+                for i in range(ti_embedding.shape[0]):
+                    embedding = ti_embedding[i].detach().numpy()
+                    trigger = _get_trigger(ti_name, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if embeddings[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                            f" {embedding.shape[0]}, but the current model has token dimension"
+                            f" {embeddings[token_id].shape[0]}."
+                        )
+
+                    embeddings[token_id] = embedding
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
+                orig_embeddings.dtype
+            )
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            # restore
+            if orig_embeddings is not None:
+                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index d72f55794d3..aed5eb60d57 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -102,7 +102,7 @@ def apply_sdxl_lora_text_encoder2(
     def apply_lora(
         cls,
         model: torch.nn.Module,
-        loras: List[Tuple[LoRAModel, float]],
+        loras: List[Tuple[LoRAModel, float]],  # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw
         prefix: str,
     ):
         original_weights = {}
@@ -194,6 +194,8 @@ def _get_trigger(ti_name, index):
         return f"<{trigger}>"
 
     def _get_ti_embedding(model_embeddings, ti):
+        print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}")
+        print(f"DEBUG: is it an nn.Module? 
{isinstance(model_embeddings, torch.nn.Module)}") # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -202,6 +204,7 @@ def _get_ti_embedding(model_embeddings, ti): else ti.embedding ) else: + print(f"DEBUG: ti.embedding={type(ti.embedding)}") return ti.embedding # modify tokenizer diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index e59a84d7291..4488f8eafc5 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -28,9 +28,11 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict -from .onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus + class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 453283e9b4a..adc84d20516 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -10,11 +10,17 @@ from diffusers.configuration_utils import ConfigMixin from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + InvalidModelConfigException, + ModelRepoVariant, + SubModelType, +) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -160,4 +166,3 @@ def _load_model( submodel_type: Optional[SubModelType] = None, ) -> AnyModel: raise NotImplementedError - diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 295be0c5514..346f5dc4247 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -97,4 +97,4 @@ def get_msg_line(prefix: str, val1: int, val2: int) -> str: if snapshot_1.vram is not None and snapshot_2.vram is not None: msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - return "\n"+msg if len(msg)>0 else msg + return "\n" + msg if len(msg) > 0 else msg diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 50cafa37696..6c87e2519e5 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,5 +1,3 @@ """Init file for RamCache.""" -from .model_cache_base import ModelCacheBase -from .model_cache_default import ModelCache _all__ = ["ModelCacheBase", 
"ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index 8e6a80ceb20..e61e2b46a63 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -14,8 +14,10 @@ ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers from invokeai.backend.model_manager.load.load_base import AnyModelLoader + from .generic_diffusers import GenericDiffusersLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) class ControlnetLoader(GenericDiffusersLoader): @@ -37,7 +39,7 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") else: - assert hasattr(config, 'config') + assert hasattr(config, "config") config_file = config.config if weights_path.suffix == ".safetensors": diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index f92a9048c50..03c26f3a0c0 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -15,6 +15,7 @@ from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) class GenericDiffusersLoader(ModelLoader): diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py index 63dc3790f16..27ced41c1e9 100644 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -1,11 +1,11 @@ # Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team """Class for IP Adapter model loading in InvokeAI.""" -import torch - from pathlib import Path from typing import Optional +import torch + from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter from invokeai.backend.model_manager import ( AnyModel, @@ -18,6 +18,7 @@ from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) class IPAdapterInvokeAILoader(ModelLoader): """Class to load IP Adapter diffusers models.""" @@ -36,4 +37,3 @@ def _load_model( dtype=self._torch_dtype, ) return model - diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 4d19aadb7d2..d8e5f920e24 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -2,13 +2,12 @@ """Class for LoRA model loading in InvokeAI.""" +from logging import Logger from pathlib import Path from typing import Optional, Tuple -from logging import Logger -from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase -from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.embeddings.lora import LoRAModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, @@ -18,9 +17,11 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.lora import LoRAModelRaw +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) @@ -47,6 +48,7 @@ def _load_model( ) -> AnyModel: if submodel_type is not None: raise ValueError("There are no submodels in a LoRA model.") + assert self._model_base is not None model = LoRAModelRaw.from_checkpoint( file_path=model_path, dtype=self._torch_dtype, @@ -56,9 +58,11 @@ def _load_model( # override def _get_model_path( - self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: - self._model_base = config.base # cheating a little - setting this variable for later call to _load_model() + self._model_base = ( + config.base + ) # cheating a little - we remember this variable for using in the subsequent call to _load_model() model_base_path = self._app_config.models_path model_path = model_base_path / config.path @@ -72,5 +76,3 @@ def _get_model_path( result = model_path.resolve(), config, submodel_type return result - - diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py new file mode 100644 index 00000000000..394fddc75d0 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, 
Lincoln D. Stein and the InvokeAI Development Team +"""Class for TI model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder) +class TextualInversionLoader(ModelLoader): + """Class to load TI models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a TI model.") + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + model_path = self._app_config.models_path / config.path + + if config.format == ModelFormat.EmbeddingFolder: + path = model_path / "learned_embeds.bin" + else: + path = model_path + + if not path.exists(): + raise OSError(f"The embedding file at {path} was not found") + + return path, config, submodel_type diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 7a35e53459a..882ae055771 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -15,6 +15,7 @@ ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.load.load_base import AnyModelLoader + from .generic_diffusers import GenericDiffusersLoader diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 404c88bbbcd..3f2d22595e2 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -3,13 +3,13 @@ import json from pathlib import Path -from typing import Optional, Union +from typing import Optional import torch from diffusers import DiffusionPipeline from invokeai.backend.model_manager.config import AnyModel -from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel def calc_model_size_by_data(model: AnyModel) -> int: diff --git a/invokeai/backend/model_manager/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py similarity index 100% rename from invokeai/backend/model_manager/onnx_runtime.py rename to invokeai/backend/onnx/onnx_runtime.py From fbded1c0f2c27274ec4f2d0d9c22bff9bb5bab39 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 5 Feb 2024 21:55:11 -0500 Subject: [PATCH 077/100] Multiple refinements on loaders: - Cache stat collection enabled. - Implemented ONNX loading. - Add ability to specify the repo version variant in installer CLI. 
- If caller asks for a repo version that doesn't exist, will fall back to empty version rather than raising an error. --- .../model_install/model_install_default.py | 6 +-- invokeai/backend/install/install_helper.py | 18 ++++++-- invokeai/backend/model_manager/config.py | 14 ++++++- .../backend/model_manager/load/__init__.py | 1 - .../backend/model_manager/load/load_base.py | 4 +- .../model_manager/load/load_default.py | 32 ++++++++++---- .../load/model_cache/__init__.py | 3 +- .../load/model_cache/model_cache_base.py | 10 ++++- .../load/model_cache/model_cache_default.py | 42 +++++++++++++++++-- .../model_manager/load/model_loaders/onnx.py | 41 ++++++++++++++++++ .../model_manager/metadata/fetch/civitai.py | 7 +++- .../metadata/fetch/fetch_base.py | 7 +++- .../metadata/fetch/huggingface.py | 26 ++++++++---- .../model_manager/metadata/metadata_base.py | 1 - invokeai/backend/model_manager/probe.py | 16 +++++-- .../model_manager/util/select_hf_files.py | 14 +++++-- invokeai/backend/util/devices.py | 20 ++++++--- invokeai/frontend/install/model_install2.py | 2 +- 18 files changed, 215 insertions(+), 49 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_loaders/onnx.py diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1c188b300df..d32af4a513d 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -495,10 +495,10 @@ def _next_id(self) -> int: return id @staticmethod - def _guess_variant() -> ModelRepoVariant: + def _guess_variant() -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT + return ModelRepoVariant.FP16 if precision == "float16" else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( @@ -523,7 +523,7 @@ def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any] if not source.access_token: self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id) + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) assert isinstance(metadata, ModelMetadataWithFiles) remote_files = metadata.download_urls( variant=source.variant or self._guess_variant(), diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 9f219132d4d..57dfadcaeaa 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -30,6 +30,7 @@ from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, + ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException @@ -233,11 +234,18 @@ def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: if model_path.exists(): # local file on disk return LocalModelSource(path=model_path.absolute(), inplace=True) - if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id + + # parsing huggingface repo ids + # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" + variants = "|".join([x.lower() for x in 
ModelRepoVariant.__members__]) + if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): + repo_id = match.group(1) + repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None return HFModelSource( - repo_id=model_path_id_or_url, + repo_id=repo_id, access_token=HfFolder.get_token(), subfolder=model_info.subfolder, + variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) @@ -278,9 +286,11 @@ def add_or_delete(self, selections: InstallSelections) -> None: model_name=model_name, ) if len(matches) > 1: - print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.") + print( + f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + ) elif not matches: - print(f"{model}: unknown model") + print(f"{model_to_remove}: unknown model") else: for m in matches: print(f"Deleting {m.type}:{m.name}") diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 4488f8eafc5..49ce6af2b81 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -109,7 +109,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - DEFAULT = "default" # model files without "fp16" or other qualifier + DEFAULT = "" # model files without "fp16" or other qualifier - empty str FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" @@ -246,6 +246,16 @@ class ONNXSD2Config(_MainConfig): upcast_attention: bool = True +class ONNXSDXLConfig(_MainConfig): + """Model config for ONNX format models based on sdxl.""" + + type: Literal[ModelType.ONNX] = ModelType.ONNX + format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + # No yaml config file for ONNX, so these are part of config + base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL + prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction + + class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" @@ -267,7 +277,7 @@ class T2IConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] -_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")] +_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")] _ControlNetConfig = Annotated[ Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format"), diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 19b0116ba3b..e4c7077f783 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -16,7 +16,6 @@ # This registers the subclasses that implement loaders of specific model types loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: - print(f"module={module}") import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel"] diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7d4e8337c3c..ee9d6d53e3d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -22,6 +22,7 @@ from 
invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.util.logging import InvokeAILogger @dataclass @@ -88,6 +89,7 @@ class AnyModelLoader: # this tracks the loader subclasses _registry: Dict[str, Type[ModelLoaderBase]] = {} + _logger: Logger = InvokeAILogger.get_logger() def __init__( self, @@ -167,7 +169,7 @@ def register( """Define a decorator which registers the subclass of loader.""" def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - print("DEBUG: Registering class", subclass.__name__) + cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") key = cls._to_registry_key(base, type, format) cls._registry[key] = subclass return subclass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index adc84d20516..757745072d1 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -52,7 +52,7 @@ def __init__( self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - self._torch_dtype = torch_dtype(choose_torch_device()) + self._torch_dtype = torch_dtype(choose_torch_device(), app_config) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -102,8 +102,10 @@ def _load_if_needed( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None ) -> ModelLockerBase: # TO DO: This is not thread safe! - if self._ram_cache.exists(config.key, submodel_type): + try: return self._ram_cache.get(config.key, submodel_type) + except IndexError: + pass model_variant = getattr(config, "repo_variant", None) self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) @@ -119,7 +121,11 @@ def _load_if_needed( size=calc_model_size_by_data(loaded_model), ) - return self._ram_cache.get(config.key, submodel_type) + return self._ram_cache.get( + key=config.key, + submodel_type=submodel_type, + stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]), + ) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None @@ -146,13 +152,21 @@ def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # # TO DO: Add exception handling def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: if submodel_type: - config = self._load_diffusers_config(model_path, config_name="model_index.json") - module, class_name = config[submodel_type.value] - return self._hf_definition_to_type(module=module, class_name=class_name) + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' 
+ ) from e else: - config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config["_class_name"] - return self._hf_definition_to_type(module="diffusers", class_name=class_name) + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config["_class_name"] + return self._hf_definition_to_type(module="diffusers", class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e # This needs to be implemented in subclasses that handle checkpoints def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 6c87e2519e5..0cb5184f3a4 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,3 +1,4 @@ -"""Init file for RamCache.""" +"""Init file for ModelCache.""" + _all__ = ["ModelCacheBase", "ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 14a7dfb4a1f..b1a6768ee8f 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -129,11 +129,17 @@ def get( self, key: str, submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, ) -> ModelLockerBase: """ - Retrieve model locker object using key and optional submodel_type. + Retrieve model using key and optional submodel_type. - This may return an UnknownModelException if the model is not in the cache. + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. 
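+
+        Illustrative call, mirroring how the default model loader uses this method
+        (the key and stats_name values here are hypothetical):
+
+            locker = cache.get(
+                key="abc123",
+                submodel_type=SubModelType.UNet,
+                stats_name="sd-1:main:my-model:unet",
+            )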
""" pass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 688be8ceb48..7e30512a588 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -24,6 +24,7 @@ import sys import time from contextlib import suppress +from dataclasses import dataclass, field from logging import Logger from typing import Dict, List, Optional @@ -55,6 +56,20 @@ MB = 2**20 +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + # {submodel_key => size} + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" @@ -94,6 +109,8 @@ def __init__( self._storage_device: torch.device = storage_device self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG + # used for stats collection + self.stats = CacheStats() self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -158,21 +175,40 @@ def get( self, key: str, submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, ) -> ModelLockerBase: """ Retrieve model using key and optional submodel_type. - This may return an IndexError if the model is not in the cache. + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. """ key = self._make_cache_key(key, submodel_type) - if key not in self._cached_models: + if key in self._cached_models: + self.stats.hits += 1 + else: + self.stats.misses += 1 raise IndexError(f"The model with key {key} is not in the cache.") + cache_entry = self._cached_models[key] + + # more stats + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) + # this moves the entry to the top (right end) of the stack with suppress(Exception): self._cache_stack.remove(key) self._cache_stack.append(key) - cache_entry = self._cached_models[key] return ModelLocker( cache=self, cache_entry=cache_entry, diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py new file mode 100644 index 00000000000..935a6b7c953 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team
+"""Class for Onnx model loading in InvokeAI."""
+
+# This should work the same as Stable Diffusion pipelines
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.model_manager import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
+class OnnxDiffusersModel(ModelLoader):
+    """Class to load onnx models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is None:
+            raise Exception("A submodel type must be provided when loading onnx pipelines.")
+        load_class = self._get_hf_load_class(model_path, submodel_type)
+        variant = model_variant.value if model_variant else None
+        model_path = model_path / submodel_type.value
+        result: AnyModel = load_class.from_pretrained(
+            model_path,
+            torch_dtype=self._torch_dtype,
+            variant=variant,
+        )  # type: ignore
+        return result
diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py
index 6e41d6f11b2..7991f6a7489 100644
--- a/invokeai/backend/model_manager/metadata/fetch/civitai.py
+++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py
@@ -32,6 +32,8 @@
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
 
+from invokeai.backend.model_manager import ModelRepoVariant
+
 from ..metadata_base import (
     AnyModelRepoMetadata,
     CivitaiMetadata,
@@ -82,10 +84,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
             return self.from_civitai_versionid(int(version_id))
         raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
 
-    def from_id(self, id: str) -> AnyModelRepoMetadata:
+    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
         """
         Given a Civitai model version ID, return a ModelRepoMetadata object.
 
+        :param id: The Civitai model version ID.
+        :param variant: A model variant from the ModelRepoVariant enum (currently ignored)
+
         May raise an `UnknownMetadataException`.
         """
         return self.from_civitai_versionid(int(id))
diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py
index 58b65b69477..d628ab5c178 100644
--- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py
+++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py
@@ -18,6 +18,8 @@
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
 
+from invokeai.backend.model_manager import ModelRepoVariant
+
 from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
 
 
@@ -45,10 +47,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
         pass
 
     @abstractmethod
-    def from_id(self, id: str) -> AnyModelRepoMetadata:
+    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
         """
         Given an ID for a model, return a ModelMetadata object.
 
+        :param id: The model's ID (interpretation is fetcher-specific).
+        :param variant: A model variant from the ModelRepoVariant enum.
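+
+        Illustrative use, mirroring the installer's call site (the repo id here is
+        hypothetical):
+
+            fetcher = HuggingFaceMetadataFetch(session)
+            metadata = fetcher.from_id("some-org/some-model", ModelRepoVariant.FP16)
+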
+ This method will raise a `UnknownMetadataException` in the event that the requested model's metadata is not found at the provided id. """ diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 5d1eb0cc9e4..6f04e8713b2 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,10 +19,12 @@ import requests from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, HuggingFaceMetadata, @@ -53,12 +55,22 @@ def from_json(cls, json: str) -> HuggingFaceMetadata: metadata = HuggingFaceMetadata.model_validate_json(json) return metadata - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" - try: - model_info = HfApi().model_info(repo_id=id, files_metadata=True) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + # Little loop which tries fetching a revision corresponding to the selected variant. + # If not available, then set variant to None and get the default. + # If this too fails, raise exception. + model_info = None + while not model_info: + try: + model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant) + except RepositoryNotFoundError as excp: + raise UnknownMetadataException(f"'{id}' not found. 
See trace for details.") from excp + except RevisionNotFoundError: + if variant is None: + raise + else: + variant = None _, name = id.split("/") return HuggingFaceMetadata( @@ -70,7 +82,7 @@ def from_id(self, id: str) -> AnyModelRepoMetadata: tags=model_info.tags, files=[ RemoteModelFile( - url=hf_hub_url(id, x.rfilename), + url=hf_hub_url(id, x.rfilename, revision=variant), path=Path(name, x.rfilename), size=x.size, sha256=x.lfs.get("sha256") if x.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5aa883d26d0..5c3afcdc960 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -184,7 +184,6 @@ def download_urls( [x.path for x in self.files], variant, subfolder ) # all files in the model prefix = f"{subfolder}/" if subfolder else "" - # the next step reads model_index.json to determine which subdirectories belong # to the model if Path(f"{prefix}model_index.json") in paths: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 64a20a20923..55a9c0464a5 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -7,6 +7,7 @@ import torch from picklescan.scanner import scan_file_path +import invokeai.backend.util.logging as logger from invokeai.backend.model_management.models.base import read_checkpoint_meta from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat from invokeai.backend.model_management.util import lora_token_vector_length @@ -590,13 +591,20 @@ def get_base_type(self) -> BaseModelType: return TextualInversionCheckpointProbe(path).get_base_type() -class ONNXFolderProbe(FolderProbeBase): +class ONNXFolderProbe(PipelineFolderProbe): + def get_base_type(self) -> BaseModelType: + # Due to the way the installer is set up, the configuration file for safetensors + # will come along for the ride if both the onnx and safetensors forms + # share the same directory. We take advantage of this here. + if (self.model_path / "unet" / "config.json").exists(): + return super().get_base_type() + else: + logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"') + return BaseModelType.StableDiffusion1 + def get_format(self) -> ModelFormat: return ModelFormat("onnx") - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 69760590440..a894d915de6 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -41,13 +41,21 @@ def filter_files( for file in files: if file.name.endswith((".json", ".txt")): paths.append(file) - elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")): + elif file.name.endswith( + ( + "learned_embeds.bin", + "ip_adapter.bin", + "lora_weights.safetensors", + "weights.pb", + "onnx_data", + ) + ): paths.append(file) # BRITTLENESS WARNING!! # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid # downloading random checkpoints that might also be in the repo. 
However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models - # will adhere to this naming convention, so this is an area of brittleness. + # will adhere to this naming convention, so this is an area to be careful of. elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) @@ -64,7 +72,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path result = set() basenames: Dict[Path, Path] = {} for path in files: - if path.suffix == ".onnx": + if path.suffix in [".onnx", ".pb", ".onnx_data"]: if variant == ModelRepoVariant.ONNX: result.add(path) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index a787f9b6f42..b4f24d8483b 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device: return torch.device(config.device) -def choose_precision(device: torch.device) -> str: - """Returns an appropriate precision for the given torch device""" +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str: + """Return an appropriate precision for the given torch device.""" + app_config = app_config or config if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): - if config.precision == "bfloat16": + if app_config.precision == "float32": + return "float32" + elif app_config.precision == "bfloat16": return "bfloat16" else: return "float16" @@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str: return "float32" -def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype: +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def torch_dtype( + device: Optional[torch.device] = None, + app_config: Optional[InvokeAIAppConfig] = None, +) -> torch.dtype: device = device or choose_torch_device() - precision = choose_precision(device) + precision = choose_precision(device, app_config) if precision == "float16": return torch.float16 if precision == "bfloat16": diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py index 6eb480c8d9d..51a633a5654 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install2.py @@ -505,7 +505,7 @@ def list_models(installer: ModelInstallService, model_type: ModelType): print(f"Installed models of type `{model_type}`:") for model in models: path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:14}{path}") + print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # -------------------------------------------------------- From dfcf38be91bc11d5d9ba79536eafc1cbc61f8260 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 5 Feb 2024 22:56:32 -0500 Subject: [PATCH 078/100] BREAKING CHANGES: invocations now require model key, not base/type/name - Implement new model loader and modify invocations and embeddings - Finish implementation loaders for all models currently supported by InvokeAI. 
- Move lora, textual_inversion, and model patching support into backend/embeddings. - Restore support for model cache statistics collection (a little ugly, needs work). - Fixed up invocations that load and patch models. - Move seamless and silencewarnings utils into better location --- invokeai/app/api/routers/download_queue.py | 2 +- invokeai/app/invocations/compel.py | 114 +++++++----- .../controlnet_image_processors.py | 7 +- invokeai/app/invocations/ip_adapter.py | 49 +++-- invokeai/app/invocations/latent.py | 86 +++++---- invokeai/app/invocations/model.py | 176 +++++------------- invokeai/app/invocations/sdxl.py | 74 ++------ invokeai/app/invocations/t2i_adapter.py | 8 +- invokeai/app/services/events/events_base.py | 27 +-- .../invocation_stats_default.py | 16 +- .../model_records/model_records_base.py | 47 ++++- .../model_records/model_records_sql.py | 92 ++++++++- invokeai/backend/embeddings/__init__.py | 4 + invokeai/backend/embeddings/embedding_base.py | 12 ++ invokeai/backend/embeddings/lora.py | 14 +- invokeai/backend/embeddings/model_patcher.py | 134 +++---------- .../backend/embeddings/textual_inversion.py | 100 ++++++++++ invokeai/backend/install/install_helper.py | 3 +- invokeai/backend/model_manager/config.py | 5 +- .../backend/model_manager/load/load_base.py | 4 +- .../model_manager/load/load_default.py | 4 +- .../load/model_cache/__init__.py | 4 +- .../load/model_cache/model_cache_base.py | 33 +++- .../load/model_cache/model_cache_default.py | 53 +++--- .../load/model_loaders/textual_inversion.py | 2 +- invokeai/backend/stable_diffusion/__init__.py | 9 + invokeai/backend/stable_diffusion/seamless.py | 102 ++++++++++ invokeai/backend/util/silence_warnings.py | 28 +++ invokeai/frontend/install/model_install2.py | 8 +- .../util/test_hf_model_select.py | 2 + tests/test_model_probe.py | 6 +- 31 files changed, 728 insertions(+), 497 deletions(-) create mode 100644 invokeai/backend/embeddings/__init__.py create mode 100644 invokeai/backend/embeddings/embedding_base.py create mode 100644 invokeai/backend/embeddings/textual_inversion.py create mode 100644 invokeai/backend/stable_diffusion/seamless.py create mode 100644 invokeai/backend/util/silence_warnings.py diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py index 92b658c3708..2dba376c181 100644 --- a/invokeai/app/api/routers/download_queue.py +++ b/invokeai/app/api/routers/download_queue.py @@ -55,7 +55,7 @@ async def download( ) -> DownloadJob: """Download the source URL to the file or directory indicted in dest.""" queue = ApiDependencies.invoker.services.download_queue - return queue.download(source, dest, priority, access_token) + return queue.download(source, Path(dest), priority, access_token) @download_queue_router.get( diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 978c6dcb17f..0e1a6bdc6fb 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,9 +1,10 @@ -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +import invokeai.backend.util.logging as logger from invokeai.app.invocations.fields import ( FieldDescriptions, Input, @@ -12,18 +13,21 @@ UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.services.model_records 
import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.backend.embeddings.lora import LoRAModelRaw
+from invokeai.backend.embeddings.model_patcher import ModelPatcher
+from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.model_manager import ModelType
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
+from invokeai.backend.util.devices import torch_dtype
 
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.models import ModelNotFoundException, ModelType
-from ...backend.util.devices import torch_dtype
-from ..util.ti_utils import extract_ti_triggers_from_prompt
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -64,13 +68,22 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        tokenizer_info = context.services.model_records.load_model(
+            **self.clip.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_records.load_model(
+            **self.clip.text_encoder.model_dump(),
+            context=context,
+        )
 
-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.services.model_records.load_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
@@ -80,24 +93,20 @@ def _lora_loader():
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
-                    (
-                        name,
-                        context.models.load(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
+                loaded_model = context.services.model_records.load_model_by_attr(
+                    model_name=name,
+                    base_model=text_encoder_info.config.base,
+                    model_type=ModelType.TextualInversion,
+                    context=context,
+                ).model
+                assert isinstance(loaded_model, TextualInversionModelRaw)
+                ti_list.append((name, loaded_model))
+            except UnknownModelException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -105,7 +114,7 @@ def _lora_loader():
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
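             # (Illustrative aside: each ModelPatcher helper in this `with (...)` block is
             # a context manager that patches the model on entry and restores the original
             # weights on exit, so the patches unwind in reverse order even on error, e.g.:
             #
             #     with ModelPatcher.apply_lora_text_encoder(text_encoder, loras):
             #         ...  # text_encoder temporarily has the LoRA weights merged in
             #
             # The names mirror this diff; the snippet is a sketch, not repository code.)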
-            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -144,6 +153,8 @@ def _lora_loader():
 
 
 class SDXLPromptInvocationBase:
+    """Prompt processor for SDXL models."""
+
     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -152,20 +163,27 @@ def run_clip_compel(
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ):
-        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
+        tokenizer_info = context.services.model_records.load_model(
+            **clip_field.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_records.load_model(
+            **clip_field.text_encoder.model_dump(),
+            context=context,
+        )
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.context.model
+            cpu_text_encoder = text_encoder_info.model
+            assert isinstance(cpu_text_encoder, torch.nn.Module)
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=text_encoder_info.context.cache.precision,
+                dtype=cpu_text_encoder.dtype,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -176,10 +194,14 @@ def run_clip_compel(
                 c_pooled = None
             return c, c_pooled, None
 
-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.services.model_records.load_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                lora_model = lora_info.model
+                assert isinstance(lora_model, LoRAModelRaw)
+                yield (lora_model, lora.weight)
                 del lora_info
             return
 
@@ -189,24 +211,24 @@ def _lora_loader():
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
-                    (
-                        name,
-                        context.models.load(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
+                ti_model = context.services.model_records.load_model_by_attr(
+                    model_name=name,
+                    base_model=text_encoder_info.config.base,
+                    model_type=ModelType.TextualInversion,
+                    context=context,
+                ).model
+                assert isinstance(ti_model, TextualInversionModelRaw)
+                ti_list.append((name, ti_model))
+            except UnknownModelException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
+                logger.warning(f'trigger: "{trigger}" not found')
+            except ValueError:
+                logger.warning(f'trigger: "{trigger}" matched more than one similarly-named textual inversion model')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -214,7 +236,7 @@ def _lora_loader():
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
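             # (Illustrative aside: "CLIP skip" stops the text encoder early; roughly,
             # skipped_layers == 1 means the conditioning is built from hidden_states[-2]
             # rather than the final hidden state. The exact mechanics live in
             # ModelPatcher.apply_clip_skip and are assumed here, not shown in this diff.)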
- ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -332,6 +354,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: dim=1, ) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -380,6 +403,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 37954c1097e..580ee085627 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -23,7 +23,7 @@ ) from controlnet_aux.util import HWC3, ade_palette from PIL import Image -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.fields import ( FieldDescriptions, @@ -60,10 +60,7 @@ class ControlNetModelField(BaseModel): """ControlNet model field""" - model_name: str = Field(description="Name of the ControlNet model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model config record key for the ControlNet model") class ControlField(BaseModel): diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 845fcfa2848..700b285a45f 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -2,7 +2,8 @@ from builtins import float from typing import List, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -18,18 +19,13 @@ from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +# LS: Consider moving these two classes into model.py class IPAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the IP-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the IP-Adapter model") class CLIPVisionModelField(BaseModel): - model_name: str = Field(description="Name of the CLIP Vision image encoder model") - base_model: BaseModelType = Field(description="Base model (usually 'Any')") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the CLIP Vision image encoder model") class IPAdapterField(BaseModel): @@ -46,16 +42,26 @@ class IPAdapterField(BaseModel): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self +def 
get_ip_adapter_image_encoder_model_id(model_path: str): + """Read the ID of the image encoder associated with the IP-Adapter at `model_path`.""" + image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") + + with open(image_encoder_config_file, "r") as f: + image_encoder_model = f.readline().strip() + + return image_encoder_model + + @invocation_output("ip_adapter_output") class IPAdapterOutput(BaseInvocationOutput): # Outputs @@ -84,33 +90,36 @@ class IPAdapterInvocation(BaseInvocation): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. - ip_adapter_info = context.models.get_info( - self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter - ) + ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model # directly, and 2) we are reading from disk every time this invocation is called without caching the result. # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. + # TODO (LS): Fix the issue above by: + # 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model. + # 2. Update probe.py to read `image_encoder.txt` and store it in the config. + # 3. Change below to get the image encoder from the configuration record. 
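+        # (Illustration with a hypothetical value: if image_encoder.txt contains the
+        # single line "InvokeAI/ip_adapter_sd_image_encoder", the helper above returns
+        # that repo_id; split("/")[-1] below then yields "ip_adapter_sd_image_encoder",
+        # which is matched by name against installed CLIPVision models.)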
image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.config.get().models_path, ip_adapter_info["path"]) + os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_model = CLIPVisionModelField( - model_name=image_encoder_model_name, - base_model=BaseModelType.Any, + image_encoder_models = context.services.model_records.search_by_attr( + model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) + assert len(image_encoder_models) == 1 + image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 69e3f055ca8..063b23fa589 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,13 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np import torch import torchvision.transforms as T -from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -46,14 +46,13 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_management.models import ModelType, SilenceWarnings +from invokeai.backend.model_manager import AnyModel, BaseModelType +from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.util.silence_warnings import SilenceWarnings -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import BaseModelType -from ...backend.model_management.seamless import set_seamless -from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, IPAdapterData, @@ -149,7 +148,10 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: ) if image is not None: - vae_info = context.models.load(**self.vae.vae.model_dump()) + vae_info = context.services.model_records.load_model( + **self.vae.vae.model_dump(), + context=context, + ) img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) @@ -175,7 +177,10 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) + orig_scheduler_info = context.services.model_records.load_model( + **scheduler_info.model_dump(), + context=context, + ) with orig_scheduler_info as 
orig_scheduler: scheduler_config = orig_scheduler.config @@ -389,10 +394,9 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.models.load( - model_name=control_info.control_model.model_name, - model_type=ModelType.ControlNet, - base_model=control_info.control_model.base_model, + context.services.model_records.load_model( + key=control_info.control_model.key, + context=context, ) ) @@ -456,17 +460,15 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.models.load( - model_name=single_ip_adapter.ip_adapter_model.model_name, - model_type=ModelType.IPAdapter, - base_model=single_ip_adapter.ip_adapter_model.base_model, + context.services.model_records.load_model( + key=single_ip_adapter.ip_adapter_model.key, + context=context, ) ) - image_encoder_model_info = context.models.load( - model_name=single_ip_adapter.image_encoder_model.model_name, - model_type=ModelType.CLIPVision, - base_model=single_ip_adapter.image_encoder_model.base_model, + image_encoder_model_info = context.services.model_records.load_model( + key=single_ip_adapter.image_encoder_model.key, + context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. @@ -518,10 +520,9 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.models.load( - model_name=t2i_adapter_field.t2i_adapter_model.model_name, - model_type=ModelType.T2IAdapter, - base_model=t2i_adapter_field.t2i_adapter_model.base_model, + t2i_adapter_model_info = context.services.model_records.load_model( + key=t2i_adapter_field.t2i_adapter_model.key, + context=context, ) image = context.images.get_pil(t2i_adapter_field.image.image_name) @@ -556,7 +557,7 @@ def run_t2i_adapters( do_classifier_free_guidance=False, width=t2i_input_width, height=t2i_input_height, - num_channels=t2i_adapter_model.config.in_channels, + num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict device=t2i_adapter_model.device, dtype=t2i_adapter_model.dtype, resize_mode=t2i_adapter_field.resize_mode, @@ -662,22 +663,30 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: def step_callback(state: PipelineIntermediateState): context.util.sd_step_callback(state, self.unet.unet.base_model) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: for lora in self.unet.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) - yield (lora_info.context.model, lora.weight) + lora_info = context.services.model_records.load_model( + **lora.model_dump(exclude={"weight"}), + context=context, + ) + yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.models.load(**self.unet.unet.model_dump()) + unet_info = context.services.model_records.load_model( + **self.unet.unet.model_dump(), + context=context, + ) + assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, self.unet.seamless_axes), + ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), + set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, # Apply the LoRA after unet has been moved to its target device for 
faster patching.
             ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
+            assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
                 noise = noise.to(device=unet.device, dtype=unet.dtype)
@@ -774,9 +783,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        vae_info = context.models.load(**self.vae.vae.model_dump())
+        vae_info = context.services.model_records.load_model(
+            **self.vae.vae.model_dump(),
+            context=context,
+        )
 
-        with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
+        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+            assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
@@ -995,7 +1008,10 @@ def vae_encode(vae_info, upcast, tiled, image_tensor):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
-        vae_info = context.models.load(**self.vae.vae.model_dump())
+        vae_info = context.services.model_records.load_model(
+            **self.vae.vae.model_dump(),
+            context=context,
+        )
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 6a1fd6d36bc..e2ea7442839 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -1,13 +1,13 @@
 import copy
 from typing import List, Optional
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field
 
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 
-from ...backend.model_management import BaseModelType, ModelType, SubModelType
+from ...backend.model_manager import SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -17,13 +17,9 @@ class ModelInfo(BaseModel):
-    model_name: str = Field(description="Info to load submodel")
-    base_model: BaseModelType = Field(description="Base model")
-    model_type: ModelType = Field(description="Info to load submodel")
+    key: str = Field(description="Info to load submodel")
     submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
 
-    model_config = ConfigDict(protected_namespaces=())
-
 
 class LoraInfo(ModelInfo):
     weight: float = Field(description="Lora's weight which to use when apply to model")
@@ -52,7 +48,7 @@ class VaeField(BaseModel):
 
 @invocation_output("unet_output")
 class UNetOutput(BaseInvocationOutput):
-    """Base class for invocations that output a UNet field"""
+    """Base class for invocations that output a UNet field."""
 
     unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
 
@@ -81,20 +77,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
 class MainModelField(BaseModel):
     """Main model field"""
 
-    model_name: str = Field(description="Name of the model")
-    base_model: BaseModelType = Field(description="Base model")
-    model_type: ModelType = Field(description="Model Type")
-
-    model_config = ConfigDict(protected_namespaces=())
+    key: str = Field(description="Model key")
 
 
 class LoRAModelField(BaseModel):
     """LoRA model field"""
 
-    model_name: str = Field(description="Name of the LoRA model")
-    base_model: BaseModelType = Field(description="Base model")
-
-    model_config = ConfigDict(protected_namespaces=())
+    key: str = Field(description="LoRA model key")
 
 
 @invocation(
@@ -111,74 +100,31 @@ class MainModelLoaderInvocation(BaseInvocation):
     # TODO: precision?
 
     def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
-        base_model = self.model.base_model
-        model_name = self.model.model_name
-        model_type = ModelType.Main
+        key = self.model.key
 
         # TODO: not found exceptions
-        if not context.models.exists(
-            model_name=model_name,
-            base_model=base_model,
-            model_type=model_type,
-        ):
-            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
-
-        """
-        if not context.services.model_manager.model_exists(
-            model_name=self.model_name,
-            model_type=SDModelType.Diffusers,
-            submodel=SDModelType.Tokenizer,
-        ):
-            raise Exception(
-                f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
-            )
-
-        if not context.services.model_manager.model_exists(
-            model_name=self.model_name,
-            model_type=SDModelType.Diffusers,
-            submodel=SDModelType.TextEncoder,
-        ):
-            raise Exception(
-                f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
-            )
-
-        if not context.services.model_manager.model_exists(
-            model_name=self.model_name,
-            model_type=SDModelType.Diffusers,
-            submodel=SDModelType.UNet,
-        ):
-            raise Exception(
-                f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
-            )
-        """
+        if not context.services.model_records.exists(key):
+            raise Exception(f"Unknown model {key}")
 
         return ModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.TextEncoder,
                 ),
                 loras=[],
@@ -186,9 +132,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
             ),
             vae=VaeField(
                 vae=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=key,
                     submodel=SubModelType.Vae,
                 ),
             ),
@@ -226,21 +170,16 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
-        base_model = self.lora.base_model
-        lora_name = self.lora.model_name
+        lora_key = self.lora.key
 
-        if not context.models.exists(
-            base_model=base_model,
-            model_name=lora_name,
-            model_type=ModelType.Lora,
-        ):
-            raise Exception(f"Unkown lora name: {lora_name}!")
+        if not context.services.model_records.exists(lora_key):
+            raise Exception(f"Unknown lora: {lora_key}!")
 
-        if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to unet')
+        if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
+            raise Exception(f'Lora "{lora_key}" already applied to unet')
 
-        if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
-            raise Exception(f'Lora "{lora_name}" already applied to clip')
+        if self.clip is not None and any(lora.key == lora_key for lora in
self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') output = LoraLoaderOutput() @@ -248,9 +187,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -260,9 +197,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -315,24 +250,19 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.models.exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unknown lora name: {lora_name}!") + if not context.services.model_records.exists(lora_key): + raise Exception(f"Unknown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') - if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip2') + if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip2') output = SDXLLoraLoaderOutput() @@ -340,9 +270,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -352,9 +280,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -364,9 +290,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip2 = copy.deepcopy(self.clip2) output.clip2.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -378,10 +302,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: class VAEModelField(BaseModel): """Vae model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model's key") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") @@ -395,25 
+316,12 @@ class VaeLoaderInvocation(BaseInvocation):
     )
 
     def invoke(self, context: InvocationContext) -> VAEOutput:
-        base_model = self.vae_model.base_model
-        model_name = self.vae_model.model_name
-        model_type = ModelType.Vae
-
-        if not context.models.exists(
-            base_model=base_model,
-            model_name=model_name,
-            model_type=model_type,
-        ):
-            raise Exception(f"Unkown vae name: {model_name}!")
-
-        return VAEOutput(
-            vae=VaeField(
-                vae=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
-                )
-            )
-        )
+        key = self.vae_model.key
+
+        if not context.services.model_records.exists(key):
+            raise Exception(f"Unknown vae: {key}!")
+
+        return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
 
 
 @invocation_output("seamless_output")
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index 8d51674a046..633a6477fdb 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -1,7 +1,7 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager import SubModelType
 
-from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -40,45 +40,31 @@ class SDXLModelLoaderInvocation(BaseInvocation):
     # TODO: precision?
 
     def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
-        base_model = self.model.base_model
-        model_name = self.model.model_name
-        model_type = ModelType.Main
+        model_key = self.model.key
 
         # TODO: not found exceptions
-        if not context.models.exists(
-            model_name=model_name,
-            base_model=base_model,
-            model_type=model_type,
-        ):
-            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
+        if not context.services.model_records.exists(model_key):
+            raise Exception(f"Unknown model: {model_key}")
 
         return SDXLModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.TextEncoder,
                 ),
                 loras=[],
@@ -86,15 +72,11 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
             ),
             clip2=ClipField(
                 tokenizer=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.Tokenizer2,
                 ),
                 text_encoder=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.TextEncoder2,
                 ),
                 loras=[],
@@ -102,9 +84,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
             ),
             vae=VaeField(
                 vae=ModelInfo(
-                    model_name=model_name,
-                    base_model=base_model,
-                    model_type=model_type,
+                    key=model_key,
                     submodel=SubModelType.Vae,
                 ),
             ),
@@ -129,45 +109,31 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     # TODO: precision?
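These loader invocations now resolve models through a single opaque record key rather than the old (model_name, base_model, model_type) triple. A toy illustration of that key-addressed contract; `ModelRecordStore` here is a stand-in for illustration, not the service class in this patch:

```python
from typing import Dict


class ModelRecordStore:
    """Key-addressed store mirroring the exists()/get_model() calls used above."""

    def __init__(self) -> None:
        self._records: Dict[str, dict] = {}

    def add_model(self, key: str, config: dict) -> None:
        self._records[key] = config

    def exists(self, key: str) -> bool:
        return key in self._records

    def get_model(self, key: str) -> dict:
        if key not in self._records:
            raise Exception(f"Unknown model {key}")
        return self._records[key]


store = ModelRecordStore()
store.add_model("abc123", {"name": "my-vae", "type": "vae"})
assert store.exists("abc123")
assert not store.exists("missing")
```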
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.models.exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_records.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder2, ), loras=[], @@ -175,9 +141,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Vae, ), ), diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 0f4fe66ada1..0f1e251bb36 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -12,14 +12,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_management.models.base import BaseModelType class T2IAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the T2I-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model record key for the T2I-Adapter model") class T2IAdapterField(BaseModel): diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 6b441efc2bf..90d9068b88c 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,8 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import LoadedModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModelConfig class EventServiceBase: @@ -171,10 +170,7 @@ def emit_model_load_started( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is 
requested""" self.__emit_queue_event( @@ -184,10 +180,7 @@ def emit_model_load_started( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, + "model_config": model_config.model_dump(), }, ) @@ -197,11 +190,7 @@ def emit_model_load_completed( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - loaded_model_info: LoadedModelInfo, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -211,13 +200,7 @@ def emit_model_load_completed( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, - "hash": loaded_model_info.hash, - "location": str(loaded_model_info.location), - "precision": str(loaded_model_info.precision), + "model_config": model_config.model_dump(), }, ) diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index be58aaad2dd..0c63b545ff2 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -2,6 +2,7 @@ import time from contextlib import contextmanager from pathlib import Path +from typing import Iterator import psutil import torch @@ -10,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_common import ( @@ -41,7 +42,10 @@ def start(self, invoker: Invoker) -> None: self._invoker = invoker @contextmanager - def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str): + def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + services = self._invoker.services + if services.model_records is None or services.model_records.loader is None: + yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() @@ -55,8 +59,10 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st start_ram = psutil.Process().memory_info().rss if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - if self._invoker.services.model_manager: - self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id]) + + # TO DO [LS]: clean up loader service - shouldn't be an attribute of model records + assert services.model_records.loader is not None + services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. 
@@ -73,7 +79,7 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st
         )
         self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
-    def _prune_stale_stats(self):
+    def _prune_stale_stats(self) -> None:
         """Check all graphs being tracked and prune any that have completed/errored.
 
         This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 42e3c8f83a7..e00dd4169d5 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -10,6 +10,7 @@
 
 from pydantic import BaseModel, Field
 
+from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.backend.model_manager import (
     AnyModelConfig,
@@ -19,6 +20,7 @@
     ModelType,
     SubModelType,
 )
+from invokeai.backend.model_manager.load import AnyModelLoader
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
 
@@ -110,12 +112,45 @@ def get_model(self, key: str) -> AnyModelConfig:
         pass
 
     @abstractmethod
-    def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
+    def load_model(
+        self,
+        key: str,
+        submodel: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
         """
         Load the indicated model into memory and return a LoadedModel object.
 
         :param key: Key of model config to be fetched.
-        :param submodel_type: For main (pipeline models), the submodel to fetch
+        :param submodel: For main (pipeline models), the submodel to fetch
+        :param context: Invocation context, used for event issuing.
+
+        Exceptions: UnknownModelException -- model with this key not known
+                    NotImplementedException -- a model loader was not provided at initialization time
+        """
+        pass
+
+    @abstractmethod
+    def load_model_by_attr(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Load the indicated model into memory and return a LoadedModel object.
+
+        This is provided for API compatibility with the get_model() method
+        in the original model manager. However, note that LoadedModel is
+        not the same as the original ModelInfo that was returned.
+
+        :param model_name: Name of the model config to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        :param context: The invocation context.
 
         Exceptions: UnknownModelException -- model with this key not known
                     NotImplementedException -- a model loader was not provided at initialization time
@@ -166,7 +201,7 @@ def list_models(
     @abstractmethod
     def exists(self, key: str) -> bool:
         """
-        Return True if a model with the indicated key exists in the databse.
+        Return True if a model with the indicated key exists in the database.
 
         :param key: Unique key for the model to be deleted
         """
@@ -209,6 +244,12 @@ def search_by_attr(
         """
         pass
 
+    @property
+    @abstractmethod
+    def loader(self) -> Optional[AnyModelLoader]:
+        """Return the model loader used by this instance."""
+        pass
+
     def all_models(self) -> List[AnyModelConfig]:
         """Return all the model configs in the database."""
         return self.search_by_attr()
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index b50cd17a75d..28a77b1b1ab 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -46,6 +46,8 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
+from invokeai.app.invocations.baseinvocation import InvocationContext
+from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
@@ -88,6 +90,11 @@ def db(self) -> SqliteDatabase:
         """Return the underlying database."""
         return self._db
 
+    @property
+    def loader(self) -> Optional[AnyModelLoader]:
+        """Return the model loader used by this instance."""
+        return self._loader
+
     def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
         """
         Add a model to the database.
@@ -213,20 +220,73 @@ def get_model(self, key: str) -> AnyModelConfig:
         model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
         return model
 
-    def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel:
+    def load_model(
+        self,
+        key: str,
+        submodel: Optional[SubModelType],
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
         """
         Load the indicated model into memory and return a LoadedModel object.
 
         :param key: Key of model config to be fetched.
-        :param submodel_type: For main (pipeline models), the submodel to fetch.
+        :param submodel: For main (pipeline models), the submodel to fetch.
+        :param context: Invocation context, used for event reporting.
 
         Exceptions: UnknownModelException -- model with this key not known
                     NotImplementedException -- a model loader was not provided at initialization time
         """
         if not self._loader:
             raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
+        # We can emit model loading events if we are executing with access to the invocation context.
         model_config = self.get_model(key)
-        return self._loader.load_model(model_config, submodel_type)
+        if context:
+            self._emit_load_event(
+                context=context,
+                model_config=model_config,
+            )
+        loaded_model = self._loader.load_model(model_config, submodel)
+        if context:
+            self._emit_load_event(
+                context=context,
+                model_config=model_config,
+                loaded=True,
+            )
+        return loaded_model
+
+    def load_model_by_attr(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Load the indicated model into memory and return a LoadedModel object.
+
+        This is provided for API compatibility with the get_model() method
+        in the original model manager. However, note that LoadedModel is
+        not the same as the original ModelInfo that was returned.
+
+        :param model_name: Name of the model config to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        :param context: The invocation context.
+
+        Exceptions: UnknownModelException -- model with this key not known
+                    NotImplementedException -- a model loader was not provided at initialization time
+                    ValueError -- more than one model matches this combination
+        """
+        configs = self.search_by_attr(model_name, base_model, model_type)
+        if len(configs) == 0:
+            raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
+        elif len(configs) > 1:
+            raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
+        else:
+            return self.load_model(configs[0].key, submodel)
 
     def exists(self, key: str) -> bool:
         """
@@ -416,3 +476,29 @@ def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
         return PaginatedResults(
             page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
         )
+
+    def _emit_load_event(
+        self,
+        context: InvocationContext,
+        model_config: AnyModelConfig,
+        loaded: bool = False,
+    ) -> None:
+        if context.services.queue.is_canceled(context.graph_execution_state_id):
+            raise CanceledException()
+
+        if not loaded:
+            context.services.events.emit_model_load_started(
+                queue_id=context.queue_id,
+                queue_item_id=context.queue_item_id,
+                queue_batch_id=context.queue_batch_id,
+                graph_execution_state_id=context.graph_execution_state_id,
+                model_config=model_config,
+            )
+        else:
+            context.services.events.emit_model_load_completed(
+                queue_id=context.queue_id,
+                queue_item_id=context.queue_item_id,
+                queue_batch_id=context.queue_batch_id,
+                graph_execution_state_id=context.graph_execution_state_id,
+                model_config=model_config,
+            )
diff --git a/invokeai/backend/embeddings/__init__.py b/invokeai/backend/embeddings/__init__.py
new file mode 100644
index 00000000000..46ead533c4d
--- /dev/null
+++ b/invokeai/backend/embeddings/__init__.py
@@ -0,0 +1,4 @@
+"""Initialization file for invokeai.backend.embeddings modules."""
+
+# from .model_patcher import ModelPatcher
+# __all__ = ["ModelPatcher"]
diff --git a/invokeai/backend/embeddings/embedding_base.py b/invokeai/backend/embeddings/embedding_base.py
new file mode 100644
index 00000000000..5e752a29e14
--- /dev/null
+++ b/invokeai/backend/embeddings/embedding_base.py
@@ -0,0 +1,12 @@
+"""Base class for LoRA and Textual Inversion models.
+
+The EmbeddingModelRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
+and is used for type checking of calls to the model patcher.
+
+The use of "Raw" here is a historical artifact, and carried forward in
+order to avoid confusion.
+""" + + +class EmbeddingModelRaw: + """Base class for LoRA and Textual Inversion models.""" diff --git a/invokeai/backend/embeddings/lora.py b/invokeai/backend/embeddings/lora.py index 9a59a977087..3c7ef074efe 100644 --- a/invokeai/backend/embeddings/lora.py +++ b/invokeai/backend/embeddings/lora.py @@ -11,6 +11,8 @@ from invokeai.backend.model_manager import BaseModelType +from .embedding_base import EmbeddingModelRaw + class LoRALayerBase: # rank: Optional[int] @@ -317,7 +319,7 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype) @@ -367,7 +369,7 @@ def to( # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw: # (torch.nn.Module): +class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] @@ -471,16 +473,16 @@ def from_checkpoint( file_path = Path(file_path) model = cls( - name=file_path.stem, # TODO: + name=file_path.stem, layers={}, ) if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + sd = load_file(file_path.absolute().as_posix(), device="cpu") else: - state_dict = torch.load(file_path, map_location="cpu") + sd = torch.load(file_path, map_location="cpu") - state_dict = cls._group_state(state_dict) + state_dict = cls._group_state(sd) if base_model == BaseModelType.StableDiffusionXL: state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py index 6d73235197d..4725181b8ed 100644 --- a/invokeai/backend/embeddings/model_patcher.py +++ b/invokeai/backend/embeddings/model_patcher.py @@ -4,22 +4,20 @@ import pickle from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple import numpy as np import torch -from compel.embeddings_provider import BaseTextualInversionManager -from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel -from safetensors.torch import load_file +from diffusers import OnnxRuntimeModel, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer -from typing_extensions import Self from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel from .lora import LoRAModelRaw +from .textual_inversion import TextualInversionManager, TextualInversionModelRaw """ loras = [ @@ -67,7 +65,7 @@ def apply_lora_unet( cls, unet: UNet2DConditionModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -76,8 +74,8 @@ def apply_lora_unet( def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModelRaw, float]], - ): + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -87,7 +85,7 @@ def apply_sdxl_lora_text_encoder( cls, text_encoder: CLIPTextModel, loras: List[Tuple[LoRAModelRaw, float]], - ): + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te1_"): yield @@ -97,7 +95,7 @@ def 
apply_sdxl_lora_text_encoder2( cls, text_encoder: CLIPTextModel, loras: List[Tuple[LoRAModelRaw, float]], - ): + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te2_"): yield @@ -105,10 +103,10 @@ def apply_sdxl_lora_text_encoder2( @contextmanager def apply_lora( cls, - model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel], - loras: List[Tuple[LoRAModelRaw, float]], + model: AnyModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> Generator[None, None, None]: + ) -> None: original_weights = {} try: with torch.no_grad(): @@ -125,6 +123,7 @@ def apply_lora( # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA # weights to have valid keys. + assert isinstance(model, torch.nn.Module) module_key, module = cls._resolve_lora_key(model, layer_key, prefix) # All of the LoRA weight calculations will be done on the same device as the module weight. @@ -170,8 +169,8 @@ def apply_ti( cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - ti_list: List[Tuple[str, TextualInversionModel]], - ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + ti_list: List[Tuple[str, TextualInversionModelRaw]], + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: init_tokens_count = None new_tokens_added = None @@ -201,7 +200,7 @@ def _get_trigger(ti_name: str, index: int) -> str: trigger += f"-!pad-{i}" return f"<{trigger}>" - def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor: + def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor: # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -229,6 +228,7 @@ def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionMod model_embeddings = text_encoder.get_input_embeddings() for ti_name, ti in ti_list: + assert isinstance(ti, TextualInversionModelRaw) ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) ti_tokens = [] @@ -267,7 +267,7 @@ def apply_clip_skip( cls, text_encoder: CLIPTextModel, clip_skip: int, - ) -> Generator[None, None, None]: + ) -> None: skipped_layers = [] try: for _i in range(clip_skip): @@ -285,7 +285,7 @@ def apply_freeu( cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ) -> Generator[None, None, None]: + ) -> None: did_apply_freeu = False try: assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? 
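All of the `ModelPatcher` context managers above share one shape: snapshot the weights they are about to touch, mutate them in place, and restore the originals in a `finally` block so the cached model is left unchanged. A self-contained sketch of that save/patch/restore pattern, with the LoRA key-resolution logic omitted; `patch_weights` is illustrative, not the API in this patch:

```python
from contextlib import contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def patch_weights(module: torch.nn.Module, deltas: Dict[str, torch.Tensor]) -> Iterator[None]:
    """Add `deltas` to named parameters in place, restoring the originals on exit."""
    originals: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for name, delta in deltas.items():
                param = module.get_parameter(name)
                originals[name] = param.detach().clone()
                param.add_(delta.to(device=param.device, dtype=param.dtype))
        yield
    finally:
        with torch.no_grad():
            for name, original in originals.items():
                module.get_parameter(name).copy_(original)


layer = torch.nn.Linear(4, 4)
before = layer.weight.detach().clone()
with patch_weights(layer, {"weight": torch.full((4, 4), 0.1)}):
    assert not torch.equal(layer.weight, before)  # patched
assert torch.equal(layer.weight, before)  # restored on exit
```

Restoring from a snapshot, rather than subtracting the delta back out, avoids floating-point drift when the patch is applied at a different precision than the stored weights.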
@@ -301,94 +301,6 @@ def apply_freeu( unet.disable_freeu() -class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] - embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> Self: - if not isinstance(file_path, Path): - file_path = Path(file_path) - - result = cls() # TODO: - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - # both v1 and v2 format embeddings - # difference mostly in metadata - if "string_to_param" in state_dict: - if len(state_dict["string_to_param"]) > 1: - print( - f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', - " token will be used.", - ) - - result.embedding = next(iter(state_dict["string_to_param"].values())) - - # v3 (easynegative) - elif "emb_params" in state_dict: - result.embedding = state_dict["emb_params"] - - # v5(sdxl safetensors file) - elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] - - # v4(diffusers bin files) - else: - result.embedding = next(iter(state_dict.values())) - - if len(result.embedding.shape) == 1: - result.embedding = result.embedding.unsqueeze(0) - - if not isinstance(result.embedding, torch.Tensor): - raise ValueError(f"Invalid embeddings file: {file_path.name}") - - return result - - -# no type hints for BaseTextualInversionManager? -class TextualInversionManager(BaseTextualInversionManager): # type: ignore - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer - - def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} - self.tokenizer = tokenizer - - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: - if len(self.pad_tokens) == 0: - return token_ids - - if token_ids[0] == self.tokenizer.bos_token_id: - raise ValueError("token_ids must not start with bos_token_id") - if token_ids[-1] == self.tokenizer.eos_token_id: - raise ValueError("token_ids must not end with eos_token_id") - - new_token_ids = [] - for token_id in token_ids: - new_token_ids.append(token_id) - if token_id in self.pad_tokens: - new_token_ids.extend(self.pad_tokens[token_id]) - - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. 
- max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 - if len(new_token_ids) > max_length: - new_token_ids = new_token_ids[0:max_length] - - return new_token_ids - - class ONNXModelPatcher: @classmethod @contextmanager @@ -396,7 +308,7 @@ def apply_lora_unet( cls, unet: OnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -406,7 +318,7 @@ def apply_lora_text_encoder( cls, text_encoder: OnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -419,7 +331,7 @@ def apply_lora( model: IAIOnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> Generator[None, None, None]: + ) -> None: from .models.base import IAIOnnxRuntimeModel if not isinstance(model, IAIOnnxRuntimeModel): @@ -506,7 +418,7 @@ def apply_ti( tokenizer: CLIPTokenizer, text_encoder: IAIOnnxRuntimeModel, ti_list: List[Tuple[str, Any]], - ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: from .models.base import IAIOnnxRuntimeModel if not isinstance(text_encoder, IAIOnnxRuntimeModel): diff --git a/invokeai/backend/embeddings/textual_inversion.py b/invokeai/backend/embeddings/textual_inversion.py new file mode 100644 index 00000000000..389edff039d --- /dev/null +++ b/invokeai/backend/embeddings/textual_inversion.py @@ -0,0 +1,100 @@ +"""Textual Inversion wrapper class.""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from safetensors.torch import load_file +from transformers import CLIPTokenizer +from typing_extensions import Self + +from .embedding_base import EmbeddingModelRaw + + +class TextualInversionModelRaw(EmbeddingModelRaw): + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first',
+                    " token will be used.",
+                )
+
+            result.embedding = next(iter(state_dict["string_to_param"].values()))
+
+        # v3 (easynegative)
+        elif "emb_params" in state_dict:
+            result.embedding = state_dict["emb_params"]
+
+        # v5(sdxl safetensors file)
+        elif "clip_g" in state_dict and "clip_l" in state_dict:
+            result.embedding = state_dict["clip_g"]
+            result.embedding_2 = state_dict["clip_l"]
+
+        # v4(diffusers bin files)
+        else:
+            result.embedding = next(iter(state_dict.values()))
+
+            if len(result.embedding.shape) == 1:
+                result.embedding = result.embedding.unsqueeze(0)
+
+            if not isinstance(result.embedding, torch.Tensor):
+                raise ValueError(f"Invalid embeddings file: {file_path.name}")
+
+        return result
+
+
+# no type hints for BaseTextualInversionManager?
+class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
+    pad_tokens: Dict[int, List[int]]
+    tokenizer: CLIPTokenizer
+
+    def __init__(self, tokenizer: CLIPTokenizer):
+        self.pad_tokens = {}
+        self.tokenizer = tokenizer
+
+    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        if len(self.pad_tokens) == 0:
+            return token_ids
+
+        if token_ids[0] == self.tokenizer.bos_token_id:
+            raise ValueError("token_ids must not start with bos_token_id")
+        if token_ids[-1] == self.tokenizer.eos_token_id:
+            raise ValueError("token_ids must not end with eos_token_id")
+
+        new_token_ids = []
+        for token_id in token_ids:
+            new_token_ids.append(token_id)
+            if token_id in self.pad_tokens:
+                new_token_ids.extend(self.pad_tokens[token_id])
+
+        # Do not exceed the max model input size.
+        # The -2 here compensates for compel.embeddings_provider.get_token_ids(),
+        # which first removes and then adds back the start and end tokens.
+        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
+        if len(new_token_ids) > max_length:
+            new_token_ids = new_token_ids[0:max_length]
+
+        return new_token_ids
diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py
index 57dfadcaeaa..8877e33092c 100644
--- a/invokeai/backend/install/install_helper.py
+++ b/invokeai/backend/install/install_helper.py
@@ -241,10 +241,11 @@ def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
         if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
             repo_id = match.group(1)
             repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
+            subfolder = Path(model_info.subfolder) if model_info.subfolder else None
             return HFModelSource(
                 repo_id=repo_id,
                 access_token=HfFolder.get_token(),
-                subfolder=model_info.subfolder,
+                subfolder=subfolder,
                 variant=repo_variant,
             )
         if re.match(r"^(http|https):", model_path_id_or_url):
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 49ce6af2b81..0dcd925c84b 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -30,8 +30,11 @@
 
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 
+from ..embeddings.embedding_base import EmbeddingModelRaw
 from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 
+AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
+
 
 class InvalidModelConfigException(Exception):
     """Exception for when config parser doesn't recognized this combination of model type and format."""
@@ -299,7 +302,7 @@ class T2IConfig(ModelConfigBase):
 ]
 
 AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
-AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
+
 
 # IMPLEMENTATION NOTE:
 # The preferred alternative to the above is a discriminated Union as shown
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index ee9d6d53e3d..9d98ee30531 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -18,8 +18,8 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
+from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig
 from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.util.logging import InvokeAILogger
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 757745072d1..2192c88ac2f 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -19,7 +19,7 @@
 )
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@@ -71,7 +71,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
         model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
 
         if not model_path.exists():
-            raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
+            raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
 
         model_path = self._convert_if_needed(model_config, model_path, submodel_type)
         locker = self._load_if_needed(model_config, model_path, submodel_type)
diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py
index 0cb5184f3a4..32c682d0424 100644
--- a/invokeai/backend/model_manager/load/model_cache/__init__.py
+++ b/invokeai/backend/model_manager/load/model_cache/__init__.py
@@ -1,4 +1,6 @@
 """Init file for ModelCache."""
+from .model_cache_base import ModelCacheBase, CacheStats  # noqa F401
+from .model_cache_default import ModelCache  # noqa F401
 
-_all__ = ["ModelCacheBase", "ModelCache"]
+__all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index b1a6768ee8f..4a4a3c7d299 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -8,13 +8,13 @@
 """
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from logging import Logger
-from typing import Generic, Optional, TypeVar
+from typing import Dict, Generic, Optional, TypeVar
 
 import torch
 
-from invokeai.backend.model_manager import AnyModel, SubModelType
+from invokeai.backend.model_manager.config import AnyModel, SubModelType
 
 
 class ModelLockerBase(ABC):
@@ -65,6 +65,19 @@ def locked(self) -> bool:
         return self._locks > 0
 
 
+@dataclass
+class CacheStats(object):
+    """Collect statistics on cache performance."""
+
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
+
+
 class ModelCacheBase(ABC, Generic[T]):
     """Virtual base class for RAM model cache."""
 
@@ -98,10 +111,22 @@ def offload_unlocked_models(self, size_required: int) -> None:
         pass
 
     @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None:
+    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None:
         """Move model into the indicated device."""
         pass
 
+    @property
+    @abstractmethod
+    def stats(self) -> CacheStats:
+        """Return collected CacheStats object."""
+        pass
+
+    @stats.setter
+    @abstractmethod
+    def stats(self, stats: CacheStats) -> None:
+        """Set the CacheStats object for collecting cache statistics."""
+        pass
+
     @property
     @abstractmethod
     def logger(self) -> Logger:
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 7e30512a588..b1deb215b2b 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -24,19 +24,17 @@
 import sys
 import time
 from contextlib import suppress
-from dataclasses import dataclass, field
 from logging import Logger
 from typing import Dict, List, Optional
 
 import torch
 
-from invokeai.backend.model_manager import SubModelType
-from invokeai.backend.model_manager.load.load_base import AnyModel
+from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger
 
-from .model_cache_base import CacheRecord, ModelCacheBase
+from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
 from .model_locker import ModelLocker, ModelLockerBase
 
 if choose_torch_device() == torch.device("mps"):
@@ -56,20 +54,6 @@
 MB = 2**20
 
 
-@dataclass
-class CacheStats(object):
-    """Collect statistics on cache performance."""
-
-    hits: int = 0  # cache hits
-    misses: int = 0  # cache misses
-    high_watermark: int = 0  # amount of cache used
-    in_cache: int = 0  # number of models in cache
-    cleared: int = 0  # number of models cleared to make space
-    cache_size: int = 0  # total size of cache
-    # {submodel_key => size}
-    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
-
-
 class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""
 
@@ -110,7 +94,7 @@ def __init__(
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
         # used for stats collection
-        self.stats = CacheStats()
+        self._stats: Optional[CacheStats] = None
 
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []
@@ -140,6 +124,16 @@ def max_cache_size(self) -> float:
         """Return the cap on cache size."""
         return self._max_cache_size
 
+    @property
+    def stats(self) -> Optional[CacheStats]:
+        """Return collected CacheStats object."""
+        return self._stats
+
+    @stats.setter
+    def stats(self, stats: CacheStats) -> None:
+        """Set the CacheStats object for collecting cache statistics."""
+        self._stats = stats
+
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""
         total = 0
@@ -189,21 +183,24 @@ def get(
         """
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
-            self.stats.hits += 1
+            if self.stats:
+                self.stats.hits += 1
         else:
-            self.stats.misses += 1
+            if self.stats:
+                self.stats.misses += 1
             raise IndexError(f"The model with key {key} is not in the cache.")
 
         cache_entry = self._cached_models[key]
 
         # more stats
-        stats_name = stats_name or key
-        self.stats.cache_size = int(self._max_cache_size * GIG)
-        self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
-        self.stats.in_cache = len(self._cached_models)
-        self.stats.loaded_model_sizes[stats_name] = max(
-            self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
-        )
+        if self.stats:
+            stats_name = stats_name or key
+            self.stats.cache_size = int(self._max_cache_size * GIG)
+            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+            self.stats.in_cache = len(self._cached_models)
+            self.stats.loaded_model_sizes[stats_name] = max(
+                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+            )
 
         # this moves the entry to the top (right end) of the stack
         with suppress(Exception):
diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
index 394fddc75d0..6635f6b43fe 100644
--- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
+++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
@@ -5,7 +5,7 @@
 from pathlib import Path
 from typing import Optional, Tuple
 
-from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw
+from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py
index 212045f81b8..75e6aa0a5de 100644
--- a/invokeai/backend/stable_diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/__init__.py
@@ -4,3 +4,12 @@
 from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline  # noqa: F401
 from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
 from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
+from .seamless import set_seamless  # noqa: F401
+
+__all__ = [
+    "PipelineIntermediateState",
+    "StableDiffusionGeneratorPipeline",
+    "InvokeAIDiffuserComponent",
+
"AttentionMapSaver", + "set_seamless", +] diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py new file mode 100644 index 00000000000..bfdf9e0c536 --- /dev/null +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import List, Union + +import torch.nn as nn +from diffusers.models import AutoencoderKL, UNet2DConditionModel + + +def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) + return nn.functional.conv2d( + working, + weight, + bias, + self.stride, + nn.modules.utils._pair(0), + self.dilation, + self.groups, + ) + + +@contextmanager +def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + try: + to_restore = [] + + for m_name, m in model.named_modules(): + if isinstance(model, UNet2DConditionModel): + if ".attentions." in m_name: + continue + + if ".resnets." in m_name: + if ".conv2" in m_name: + continue + if ".conv_shortcut" in m_name: + continue + + """ + if isinstance(model, UNet2DConditionModel): + if False and ".upsamplers." in m_name: + continue + + if False and ".downsamplers." in m_name: + continue + + if True and ".resnets." in m_name: + if True and ".conv1" in m_name: + if False and "down_blocks" in m_name: + continue + if False and "mid_block" in m_name: + continue + if False and "up_blocks" in m_name: + continue + + if True and ".conv2" in m_name: + continue + + if True and ".conv_shortcut" in m_name: + continue + + if True and ".attentions." in m_name: + continue + + if False and m_name in ["conv_in", "conv_out"]: + continue + """ + + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(module, "asymmetric_padding_mode"): + del module.asymmetric_padding_mode + if hasattr(module, "asymmetric_padding"): + del module.asymmetric_padding diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py new file mode 100644 index 00000000000..068b605da97 --- /dev/null +++ b/invokeai/backend/util/silence_warnings.py @@ -0,0 +1,28 @@ +"""Context class to silence transformers and diffusers warnings.""" +import warnings +from typing import Any + +from diffusers import logging as diffusers_logging +from transformers import logging as transformers_logging + + +class SilenceWarnings(object): + """Use in context to temporarily turn off warnings from transformers & diffusers modules. 
+ + with SilenceWarnings(): + # do something + """ + + def __init__(self) -> None: + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self) -> None: + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, *args: Any) -> None: + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py index 51a633a5654..22b132370e6 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install2.py @@ -23,7 +23,7 @@ from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device @@ -499,7 +499,7 @@ def onStart(self) -> None: ) -def list_models(installer: ModelInstallService, model_type: ModelType): +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): """Print out all models of type model_type.""" models = installer.record_store.search_by_attr(model_type=model_type) print(f"Installed models of type `{model_type}`:") @@ -527,7 +527,9 @@ def select_and_download_models(opt: Namespace) -> None: install_helper.add_or_delete(selections) elif opt.default_only: - selections = InstallSelections(install_models=[install_helper.default_model()]) + default_model = install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) install_helper.add_or_delete(selections) elif opt.yes_to_all: diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager_2/util/test_hf_model_select.py index f14d9a6823a..5bef9cb2e19 100644 --- a/tests/backend/model_manager_2/util/test_hf_model_select.py +++ b/tests/backend/model_manager_2/util/test_hf_model_select.py @@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]: "text_encoder/model.onnx", "text_encoder_2/config.json", "text_encoder_2/model.onnx", + "text_encoder_2/model.onnx_data", "tokenizer/merges.txt", "tokenizer/special_tokens_map.json", "tokenizer/tokenizer_config.json", @@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]: "tokenizer_2/vocab.json", "unet/config.json", "unet/model.onnx", + "unet/model.onnx_data", "vae_decoder/config.json", "vae_decoder/model.onnx", "vae_encoder/config.json", diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index aacae06a8bb..be823e2be9f 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -2,7 +2,7 @@ import pytest -from invokeai.backend import BaseModelType +from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant from invokeai.backend.model_manager.probe import VaeFolderProbe @@ -21,10 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == "default" + assert repo_variant == 
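# Sketch of the new context manager in use; the tokenizer load is illustrative,
# and any warning-prone transformers or diffusers call fits here.
from invokeai.backend.util.silence_warnings import SilenceWarnings
from transformers import CLIPTokenizer

with SilenceWarnings():
    # Both libraries' loggers are forced to ERROR and Python warnings are
    # ignored for the duration of the block.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Afterwards the saved verbosity levels are restored; note that the warnings
# filter is reset to "default" rather than to whatever filter was active before.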
ModelRepoVariant.DEFAULT def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() - assert repo_variant == "fp16" + assert repo_variant == ModelRepoVariant.FP16 From d959276217b727c2aaaf6fdb8c24e5dbfce3f77a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Feb 2024 16:42:33 -0500 Subject: [PATCH 079/100] fix invokeai_configure script to work with new mm; rename CLIs --- .../app/services/config/config_default.py | 10 +- invokeai/backend/install/install_helper.py | 2 +- .../backend/install/invokeai_configure.py | 197 ++++---- .../model_manager/load/load_default.py | 2 +- invokeai/backend/util/devices.py | 6 +- invokeai/configs/INITIAL_MODELS.yaml | 106 +++-- ...L_MODELS2.yaml => INITIAL_MODELS.yaml.OLD} | 106 ++--- invokeai/frontend/install/model_install.py | 448 +++++------------- ...model_install2.py => model_install.py.OLD} | 448 +++++++++++++----- invokeai/frontend/install/widgets.py | 11 + ...e_diffusers2.py => merge_diffusers.py.OLD} | 0 pyproject.toml | 3 +- tests/test_model_manager.py | 47 -- 13 files changed, 690 insertions(+), 696 deletions(-) rename invokeai/configs/{INITIAL_MODELS2.yaml => INITIAL_MODELS.yaml.OLD} (59%) rename invokeai/frontend/install/{model_install2.py => model_install.py.OLD} (57%) rename invokeai/frontend/merge/{merge_diffusers2.py => merge_diffusers.py.OLD} (100%) delete mode 100644 tests/test_model_manager.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index b39e916da34..2af775372dd 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -185,7 +185,9 @@ class InvokeBatch(InvokeAISettings): INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") -DEFAULT_MAX_VRAM = 0.5 +DEFAULT_RAM_CACHE = 10.0 +DEFAULT_VRAM_CACHE = 0.25 +DEFAULT_CONVERT_CACHE = 20.0 class Categories(object): @@ -261,9 +263,9 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other) # CACHE - ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) - vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) - convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) + ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model 
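# Quick check of the effective cache settings after this change; the values
# shown assume no overrides in invokeai.yaml.
from invokeai.app.services.config import InvokeAIAppConfig

conf = InvokeAIAppConfig.get_config()
conf.parse_args([])
print(conf.ram, conf.vram, conf.convert_cache)  # 10.0 0.25 20.0 with the new defaults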
cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache) diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 8877e33092c..9c386c209ce 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -37,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger # name of the starter models file -INITIAL_MODELS = "INITIAL_MODELS2.yaml" +INITIAL_MODELS = "INITIAL_MODELS.yaml" def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 3cb7db6c82c..4dfa2b070c0 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -18,31 +18,30 @@ from enum import Enum from pathlib import Path from shutil import get_terminal_size -from typing import Any, get_args, get_type_hints +from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints from urllib import request import npyscreen -import omegaconf import psutil import torch import transformers -import yaml -from diffusers import AutoencoderKL +from diffusers import AutoencoderKL, ModelMixin from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from huggingface_hub import HfFolder from huggingface_hub import login as hf_hub_login -from omegaconf import OmegaConf -from pydantic import ValidationError +from omegaconf import DictConfig, OmegaConf +from pydantic.error_wrappers import ValidationError from tqdm import tqdm from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.install.legacy_arg_parsing import legacy_parser -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained -from invokeai.backend.model_management.model_probe import BaseModelType, ModelType +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.model_install import addModelsForm, process_and_execute +from invokeai.frontend.install.model_install import addModelsForm # TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( @@ -61,7 +60,7 @@ transformers.logging.set_verbosity_error() -def get_literal_fields(field) -> list[Any]: +def get_literal_fields(field: str) -> Tuple[Any]: return get_args(get_type_hints(InvokeAIAppConfig).get(field)) @@ -80,8 +79,7 @@ def get_literal_fields(field) -> list[Any]: GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"] GB = 1073741824 # GB in bytes HAS_CUDA = torch.cuda.is_available() -_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0) - +_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0) MAX_VRAM /= GB MAX_RAM = psutil.virtual_memory().total / GB @@ -96,13 +94,15 @@ def 
get_literal_fields(field) -> list[Any]: class DummyWidgetValue(Enum): + """Dummy widget values.""" + zero = 0 true = True false = False # -------------------------------------------- -def postscript(errors: None): +def postscript(errors: Set[str]) -> None: if not any(errors): message = f""" ** INVOKEAI INSTALLATION SUCCESSFUL ** @@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True): # --------------------------------------------- -def HfLogin(access_token) -> str: +def HfLogin(access_token) -> None: """ Helper for logging in to Huggingface The stdout capture is needed to hide the irrelevant "git credential helper" warning @@ -162,7 +162,7 @@ def HfLogin(access_token) -> str: # ------------------------------------- class ProgressBar: - def __init__(self, model_name="file"): + def __init__(self, model_name: str = "file"): self.pbar = None self.name = model_name @@ -179,6 +179,22 @@ def __call__(self, block_num, block_size, total_size): self.pbar.update(block_size) +# --------------------------------------------- +def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any): + filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731 + logger.addFilter(filter) + try: + model = model_class.from_pretrained( + model_name, + resume_download=True, + **kwargs, + ) + model.save_pretrained(destination, safe_serialization=True) + finally: + logger.removeFilter(filter) + return destination + + # --------------------------------------------- def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"): try: @@ -249,6 +265,7 @@ def download_conversion_models(): # --------------------------------------------- +# TO DO: use the download queue here. def download_realesrgan(): logger.info("Installing ESRGAN Upscaling models...") URLs = [ @@ -288,18 +305,19 @@ def download_lama(): # --------------------------------------------- -def download_support_models(): +def download_support_models() -> None: download_realesrgan() download_lama() download_conversion_models() # ------------------------------------- -def get_root(root: str = None) -> str: +def get_root(root: Optional[str] = None) -> str: if root: return root - elif os.environ.get("INVOKEAI_ROOT"): - return os.environ.get("INVOKEAI_ROOT") + elif root := os.environ.get("INVOKEAI_ROOT"): + assert root is not None + return root else: return str(config.root_path) @@ -455,6 +473,25 @@ def create(self): max_width=110, scroll_exit=True, ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 1 + self.disk = self.add_widget_intelligent( + npyscreen.Slider, + value=clip(old_opts.convert_cache, range=(0, 100), step=0.5), + out_of=100, + lowest=0.0, + step=0.5, + relx=8, + scroll_exit=True, + ) + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, name="Model RAM cache size (GB). 
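# Usage sketch for the reinstated hf_download_from_pretrained helper; the repo
# id and destination are illustrative, and the import assumes the helper stays
# module-level in invokeai_configure.
from pathlib import Path
from diffusers import AutoencoderKL
from invokeai.backend.install.invokeai_configure import hf_download_from_pretrained

dest = hf_download_from_pretrained(
    AutoencoderKL,
    "stabilityai/sd-vae-ft-mse",
    Path("/tmp/models/sd-vae-ft-mse"),
)
# The download resumes if interrupted, the model is re-saved with
# safe_serialization=True, and the spurious "fp16 is not a valid" log line is filtered.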
Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).", @@ -495,6 +532,14 @@ def create(self): ) else: self.vram = DummyWidgetValue.zero + + self.nextrely += 1 + self.add_widget_intelligent( + npyscreen.FixedText, + value="Location of the database used to store model path and configuration information:", + editable=False, + color="CONTROL", + ) self.nextrely += 1 self.outdir = self.add_widget_intelligent( FileBox, @@ -506,19 +551,21 @@ def create(self): labelColor="GOOD", begin_entry_at=40, max_height=3, + max_width=127, scroll_exit=True, ) self.autoimport_dirs = {} self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent( FileBox, - name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models", - value=str(config.root_path / config.autoimport_dir), + name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models", + value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "", select_dir=True, must_exist=False, use_two_lines=False, labelColor="GOOD", begin_entry_at=32, max_height=3, + max_width=127, scroll_exit=True, ) self.nextrely += 1 @@ -555,6 +602,10 @@ def show_hide_slice_sizes(self, value): self.attention_slice_label.hidden = not show self.attention_slice_size.hidden = not show + def show_hide_model_conf_override(self, value): + self.model_conf_override.hidden = value + self.model_conf_override.display() + def on_ok(self): options = self.marshall_arguments() if self.validate_field_values(options): @@ -584,18 +635,21 @@ def validate_field_values(self, opt: Namespace) -> bool: else: return True - def marshall_arguments(self): + def marshall_arguments(self) -> Namespace: new_opts = Namespace() for attr in [ "ram", "vram", + "convert_cache", "outdir", ]: if hasattr(self, attr): setattr(new_opts, attr, getattr(self, attr).value) for attr in self.autoimport_dirs: + if not self.autoimport_dirs[attr].value: + continue directory = Path(self.autoimport_dirs[attr].value) if directory.is_relative_to(config.root_path): directory = directory.relative_to(config.root_path) @@ -615,13 +669,14 @@ def marshall_arguments(self): class EditOptApplication(npyscreen.NPSAppManaged): - def __init__(self, program_opts: Namespace, invokeai_opts: Namespace): + def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper): super().__init__() self.program_opts = program_opts self.invokeai_opts = invokeai_opts self.user_cancelled = False self.autoload_pending = True - self.install_selections = default_user_selections(program_opts) + self.install_helper = install_helper + self.install_selections = default_user_selections(program_opts, install_helper) def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -640,16 +695,10 @@ def onStart(self): cycle_widgets=False, ) - def new_opts(self): + def new_opts(self) -> Namespace: return self.options.marshall_arguments() -def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace: - editApp = EditOptApplication(program_opts, invokeai_opts) - editApp.run() - return editApp.new_opts() - - def default_ramcache() -> float: """Run a heuristic for the default RAM cache based on installed RAM.""" @@ -660,27 +709,18 @@ def default_ramcache() -> float: ) # 2.1 is just large enough for sd 1.5 ;-) -def default_startup_options(init_file: Path) -> Namespace: +def default_startup_options(init_file: Path) -> InvokeAIAppConfig: opts = InvokeAIAppConfig.get_config() - opts.ram = 
opts.ram or default_ramcache() + opts.ram = default_ramcache() return opts -def default_user_selections(program_opts: Namespace) -> InstallSelections: - try: - installer = ModelInstall(config) - except omegaconf.errors.ConfigKeyError: - logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing") - initialize_rootdir(config.root_path, True) - installer = ModelInstall(config) - - models = installer.all_models() +def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections: + default_model = install_helper.default_model() + assert default_model is not None + default_models = [default_model] if program_opts.default_only else install_helper.recommended_models() return InstallSelections( - install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] - if program_opts.default_only - else [models[x].path or models[x].repo_id for x in installer.recommended_models()] - if program_opts.yes_to_all - else [], + install_models=default_models if program_opts.yes_to_all else [], ) @@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False): path.mkdir(parents=True, exist_ok=True) -def maybe_create_models_yaml(root: Path): - models_yaml = root / "configs" / "models.yaml" - if models_yaml.exists(): - if OmegaConf.load(models_yaml).get("__metadata__"): # up to date - return - else: - logger.info("Creating new models.yaml, original saved as models.yaml.orig") - models_yaml.rename(models_yaml.parent / "models.yaml.orig") - - with open(models_yaml, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - # ------------------------------------- -def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace): +def run_console_ui( + program_opts: Namespace, initfile: Path, install_helper: InstallHelper +) -> Tuple[Optional[Namespace], Optional[InstallSelections]]: invokeai_opts = default_startup_options(initfile) invokeai_opts.root = program_opts.root @@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - # the install-models application spawns a subprocess to install - # models, and will crash unless this is set before running. - import torch - - torch.multiprocessing.set_start_method("spawn") - - editApp = EditOptApplication(program_opts, invokeai_opts) + editApp = EditOptApplication(program_opts, invokeai_opts, install_helper) editApp.run() if editApp.user_cancelled: return (None, None) else: - return (editApp.new_opts, editApp.install_selections) + return (editApp.new_opts(), editApp.install_selections) # ------------------------------------- -def write_opts(opts: Namespace, init_file: Path): +def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None: """ Update the invokeai.yaml file with values from current settings. 
""" @@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path): new_config = InvokeAIAppConfig.get_config() new_config.root = config.root - for key, value in opts.__dict__.items(): + for key, value in opts.model_dump().items(): if hasattr(new_config, key): setattr(new_config, key, value) @@ -779,7 +802,7 @@ def default_output_dir() -> Path: # ------------------------------------- -def write_default_options(program_opts: Namespace, initfile: Path): +def write_default_options(program_opts: Namespace, initfile: Path) -> None: opt = default_startup_options(initfile) write_opts(opt, initfile) @@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path): # the legacy Args object in order to parse # the old init file and write out the new # yaml format. -def migrate_init_file(legacy_format: Path): +def migrate_init_file(legacy_format: Path) -> None: old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) new = InvokeAIAppConfig.get_config() - fields = [ - x - for x, y in InvokeAIAppConfig.model_fields.items() - if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED" - ] - for attr in fields: + for attr in InvokeAIAppConfig.model_fields.keys(): if hasattr(old, attr): try: setattr(new, attr, getattr(old, attr)) @@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path): # ------------------------------------- -def migrate_models(root: Path): +def migrate_models(root: Path) -> None: from invokeai.backend.install.migrate_to_3 import do_migrate do_migrate(root, root) @@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: ): logger.info("** Migrating invokeai.init to invokeai.yaml") migrate_init_file(old_init_file) - config.parse_args(argv=[], conf=OmegaConf.load(new_init_file)) + omegaconf = OmegaConf.load(new_init_file) + assert isinstance(omegaconf, DictConfig) + config.parse_args(argv=[], conf=omegaconf) if old_hub.exists(): migrate_models(config.root_path) @@ -849,7 +869,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--skip-sd-weights", @@ -908,6 +928,7 @@ def main() -> None: if opt.full_precision: invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) + config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) logger = InvokeAILogger().get_logger(config=config) errors = set() @@ -921,14 +942,18 @@ def main() -> None: # run this unconditionally in case new directories need to be added initialize_rootdir(config.root_path, opt.yes_to_all) - models_to_download = default_user_selections(opt) + # this will initialize the models.yaml file if not present + install_helper = InstallHelper(config, logger) + + models_to_download = default_user_selections(opt, install_helper) new_init_file = config.root_path / "invokeai.yaml" if opt.yes_to_all: write_default_options(opt, new_init_file) init_options = Namespace(precision="float32" if opt.full_precision else "float16") + else: - init_options, models_to_download = run_console_ui(opt, new_init_file) + init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper) if init_options: write_opts(init_options, new_init_file) else: @@ -943,10 +968,12 @@ def main() -> None: if opt.skip_sd_weights: logger.warning("Skipping diffusion weights download per user request") + elif 
models_to_download: - process_and_execute(opt, models_to_download) + install_helper.add_or_delete(models_to_download) postscript(errors=errors) + if not opt.yes_to_all: input("Press any key to continue...") except WindowTooSmallException as e: diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 2192c88ac2f..c1dfe729af7 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -19,7 +19,7 @@ ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase -from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index b4f24d8483b..a83d1045f70 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Optional, Union +from typing import Literal, Optional, Union import torch from torch import autocast @@ -31,7 +31,9 @@ def choose_torch_device() -> torch.device: # We are in transition here from using a single global AppConfig to allowing multiple # configurations. It is strongly recommended to pass the app_config to this function. -def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str: +def choose_precision( + device: torch.device, app_config: Optional[InvokeAIAppConfig] = None +) -> Literal["float32", "float16", "bfloat16"]: """Return an appropriate precision for the given torch device.""" app_config = app_config or config if device.type == "cuda": diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index c230665e3a6..ca2283ab811 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -1,153 +1,157 @@ # This file predefines a few models that the user may want to install. 
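# Condensed sketch of the new non-interactive configure flow above. The
# Namespace stands in for the parsed CLI options; default_user_selections is
# imported from the module where this patch defines it.
from argparse import Namespace
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper
from invokeai.backend.install.invokeai_configure import default_user_selections
from invokeai.backend.util.logging import InvokeAILogger

config = InvokeAIAppConfig.get_config()
config.parse_args([])
logger = InvokeAILogger().get_logger(config=config)
install_helper = InstallHelper(config, logger)       # also initializes the record store
opt = Namespace(default_only=True, yes_to_all=True)  # stand-in for CLI options
selections = default_user_selections(opt, install_helper)
install_helper.add_or_delete(selections)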
sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 + source: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting + source: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 + source: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting + source: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 + source: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 + source: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-vae-fp16-fix: + description: Version of the SDXL-1.0 VAE that works in half precision mode + source: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion + source: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate_v5: +sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors + source: XpucT/Deliberate recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion + source: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 + source: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion + source: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney + source: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA + source: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 + source: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster + source: 
monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - repo_id: lllyasviel/control_v11p_sd15_canny + source: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint + source: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd + source: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth + source: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae + source: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg + source: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart + source: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime + source: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - repo_id: lllyasviel/control_v11p_sd15_openpose + source: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble + source: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge + source: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle + source: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile + source: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p + source: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 + source: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 + source: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - repo_id: TencentARC/t2iadapter_depth_sd15v2 + source: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 + source: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 + source: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 + source: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 + source: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d + description: A textual inversion to use in the negative prompt to reduce bad anatomy +sd-1/lora/FlatColor: + source: https://civitai.com/models/6433/loraflatcolor + recommended: True + description: A LoRA that generates scenery using solid blocks of color sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 + source: https://civitai.com/api/download/models/83390 + description: 
Generate india ink-like landscapes sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 + source: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 + source: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 + source: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl + source: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder + source: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder + source: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/INITIAL_MODELS2.yaml b/invokeai/configs/INITIAL_MODELS.yaml.OLD similarity index 59% rename from invokeai/configs/INITIAL_MODELS2.yaml rename to invokeai/configs/INITIAL_MODELS.yaml.OLD index ca2283ab811..c230665e3a6 100644 --- a/invokeai/configs/INITIAL_MODELS2.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml.OLD @@ -1,157 +1,153 @@ # This file predefines a few models that the user may want to install. 
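# A hypothetical entry in the new schema above (key and URL invented for
# illustration): each record now uses a single `source` field that accepts
# either a HuggingFace repo_id or a direct download URL, replacing the old
# repo_id/path split.
sd-1/lora/example-lora:
  source: https://example.com/example_lora.safetensors
  description: Illustrative entry only
  recommended: False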
sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - source: runwayml/stable-diffusion-v1-5 + repo_id: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - source: runwayml/stable-diffusion-inpainting + repo_id: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - source: stabilityai/stable-diffusion-2-1 + repo_id: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - source: stabilityai/stable-diffusion-2-inpainting + repo_id: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - source: stabilityai/stable-diffusion-xl-base-1.0 + repo_id: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - source: stabilityai/stable-diffusion-xl-refiner-1.0 + repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-vae-fp16-fix: - description: Version of the SDXL-1.0 VAE that works in half precision mode - source: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-1-0-vae-fix: + description: Fine tuned version of the SDXL-1.0 VAE + repo_id: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - source: wavymulder/Analog-Diffusion + repo_id: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate: +sd-1/main/Deliberate_v5: description: Versatile model that produces detailed images up to 768px (4.27 GB) - source: XpucT/Deliberate + path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - source: 0xJustin/Dungeons-and-Diffusion + repo_id: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - source: dreamlike-art/dreamlike-photoreal-2.0 + repo_id: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - source: Envvi/Inkpunk-Diffusion + repo_id: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - source: prompthero/openjourney + repo_id: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - source: coreco/seek.art_MEGA + repo_id: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - source: naclbit/trinart_stable_diffusion_v2 + repo_id: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - source: monster-labs/control_v1p_sd15_qrcode_monster + repo_id: 
monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - source: lllyasviel/control_v11p_sd15_canny + repo_id: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - source: lllyasviel/control_v11p_sd15_inpaint + repo_id: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - source: lllyasviel/control_v11p_sd15_mlsd + repo_id: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - source: lllyasviel/control_v11f1p_sd15_depth + repo_id: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - source: lllyasviel/control_v11p_sd15_normalbae + repo_id: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - source: lllyasviel/control_v11p_sd15_seg + repo_id: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - source: lllyasviel/control_v11p_sd15_lineart + repo_id: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - source: lllyasviel/control_v11p_sd15s2_lineart_anime + repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - source: lllyasviel/control_v11p_sd15_openpose + repo_id: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - source: lllyasviel/control_v11p_sd15_scribble + repo_id: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - source: lllyasviel/control_v11p_sd15_softedge + repo_id: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - source: lllyasviel/control_v11e_sd15_shuffle + repo_id: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - source: lllyasviel/control_v11f1e_sd15_tile + repo_id: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - source: lllyasviel/control_v11e_sd15_ip2p + repo_id: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - source: TencentARC/t2iadapter_canny_sd15v2 + repo_id: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - source: TencentARC/t2iadapter_sketch_sd15v2 + repo_id: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - source: TencentARC/t2iadapter_depth_sd15v2 + repo_id: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - source: TencentARC/t2iadapter_zoedepth_sd15v1 + repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - source: TencentARC/t2i-adapter-canny-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - source: TencentARC/t2i-adapter-lineart-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - source: TencentARC/t2i-adapter-sketch-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True - description: A textual inversion to use in the negative prompt to reduce bad anatomy -sd-1/lora/FlatColor: - source: https://civitai.com/models/6433/loraflatcolor - recommended: True - description: A LoRA that generates scenery using solid blocks of color +sd-1/embedding/ahx-beta-453407d: + repo_id: sd-concepts-library/ahx-beta-453407d sd-1/lora/Ink scenery: - source: https://civitai.com/api/download/models/83390 - description: Generate india ink-like landscapes + path: 
https://civitai.com/api/download/models/83390 sd-1/ip_adapter/ip_adapter_sd15: - source: InvokeAI/ip_adapter_sd15 + repo_id: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - source: InvokeAI/ip_adapter_plus_sd15 + repo_id: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - source: InvokeAI/ip_adapter_plus_face_sd15 + repo_id: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - source: InvokeAI/ip_adapter_sdxl + repo_id: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - source: InvokeAI/ip_adapter_sd_image_encoder + repo_id: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - source: InvokeAI/ip_adapter_sdxl_image_encoder + repo_id: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index e23538ffd66..22b132370e6 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -6,47 +6,45 @@ """ This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. +It is currently named model_install2.py, but will ultimately replace model_install.py. 
""" import argparse import curses -import logging import sys -import textwrap import traceback +import warnings from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe -from pathlib import Path from shutil import get_terminal_size -from typing import Optional +from typing import Any, Dict, List, Optional, Set import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo +from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, - BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, - select_stable_diffusion_config_file, set_min_terminal_size, ) +warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger() +logger = InvokeAILogger.get_logger("ModelInstallService") +logger.setLevel("WARNING") +# logger.setLevel('DEBUG') # build a table mapping all non-printable characters to None # for stripping control characters @@ -58,44 +56,42 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" + """Replace non-printable characters in a string.""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): + """Main form for interactive TUI.""" + # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp, name, multipage=False, *args, **keywords): + def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? + super().__init__(parentApp=parentApp, name=name, **keywords) - def create(self): + def create(self) -> None: + self.installer = self.parentApp.install_helper.installer + self.model_labels = self._get_model_labels() self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - self.nextrely -= 1 + # npyscreen has no typing hints + self.nextrely -= 1 # type: ignore self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 + self.nextrely += 1 # type: ignore self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -115,9 +111,9 @@ def create(self): ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely + top_of_table = self.nextrely # type: ignore self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely + bottom_of_table = self.nextrely # type: ignore self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -162,15 +158,7 @@ def create(self): self.nextrely = bottom_of_table + 1 - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - self.nextrely += 1 - done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -186,14 +174,8 @@ def create(self): npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - label = "APPLY CHANGES & EXIT" + label = "APPLY CHANGES" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -210,17 +192,16 @@ def create(self): ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets = {} - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels + widgets: Dict[str, npyscreen.widget] = {} - self.installed_models = sorted([x for x in starters if models[x].installed]) + all_models = self.all_models # master dict of all models, indexed by key + model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] + model_labels = [self.model_labels[x] for x in model_list] widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", + name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", editable=False, labelColor="CAUTION", ) @@ -230,23 +211,24 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] + + checked = [ + model_list.index(x) + for x in model_list + if (show_recommended and all_models[x].recommended) or all_models[x].installed + ] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, + values=model_labels, + value=checked, + max_height=len(model_list) + 1, relx=4, scroll_exit=True, ), - models=keys, + models=model_list, ) self.nextrely += 1 @@ -257,14 +239,18 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: str = None, - exclude: set = None, + install_prompt: Optional[str] = None, + exclude: Optional[Set[str]] = None, ) -> dict[str, 
npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets = {} - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] + widgets: Dict[str, npyscreen.widget] = {} + all_models = self.all_models + model_list = sorted( + [x for x in all_models if all_models[x].type == model_type and x not in exclude], + key=lambda x: all_models[x].name or "", + ) model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -300,7 +286,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed + if (show_recommended and all_models[x].recommended) or all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -324,7 +310,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, + max_height=6, scroll_exit=True, editable=True, ) @@ -349,13 +335,13 @@ def add_pipeline_widgets( return widgets - def resize(self): + def resize(self) -> None: super().resize() if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] + if model_list := self.starter_pipelines.get("models"): + s.values = [self.model_labels[x] for x in model_list] - def _toggle_tables(self, value=None): + def _toggle_tables(self, value: List[int]) -> None: selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -385,17 +371,18 @@ def _toggle_tables(self, value=None): self.display() def _get_model_labels(self) -> dict[str, str]: + """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 + result = {} models = self.all_models - label_width = max([len(models[x].name) for x in models]) + label_width = max([len(models[x].name or "") for x in self.starter_models]) description_width = window_width - label_width - checkbox_width - spacing_width - result = {} - for x in models.keys(): - description = models[x].description + for key in self.all_models: + description = models[key].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -403,7 +390,8 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) + result[key] = f"%-{label_width}s %s" % (models[key].name, description) + return result def _get_columns(self) -> int: @@ -413,50 +401,40 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel( + if remove_models: + model_names = [self.all_models[x].name or "" for x in remove_models] + mods = "\n".join(model_names) + is_ok = npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) + assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations + return is_ok else: return True - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return + @property + def all_models(self) -> Dict[str, UnifiedModelInfo]: + # npyscreen doesn't having typing hints + return self.parentApp.install_helper.all_models # type: ignore - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() + @property + def starter_models(self) -> List[str]: + return self.parentApp.install_helper._starter_models # type: ignore - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs={ - "opt": app.program_opts, - "selections": app.install_selections, - "conn_out": child_conn, - }, - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() + @property + def installed_models(self) -> List[str]: + return self.parentApp.install_helper._installed_models # type: ignore - def on_back(self): + def on_back(self) -> None: self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self): + def on_cancel(self) -> None: self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self): + def on_done(self) -> None: self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -464,77 +442,7 @@ def on_done(self): self.parentApp.user_cancelled = False self.editing = False - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. 
- saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - - def marshall_arguments(self): + def marshall_arguments(self) -> None: """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -564,46 +472,24 @@ def marshall_arguments(self): models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) + selections.install_models.extend([all_models[x] for x in models_to_install]) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) - - # NOT NEEDED - DONE IN BACKEND NOW - # # special case for the ipadapter_models. If any of the adapters are - # # chosen, then we add the corresponding encoder(s) to the install list. - # section = self.ipadapter_models - # if section.get("models_selected"): - # selected_adapters = [ - # self.all_models[section["models"][x]].name for x in section.get("models_selected").value - # ] - # encoders = [] - # if any(["sdxl" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sdxl_image_encoder") - # if any(["sd15" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sd_image_encoder") - # for encoder in encoders: - # key = f"any/clip_vision/{encoder}" - # repo_id = f"InvokeAI/{encoder}" - # if key not in self.all_models: - # selections.install_models.append(repo_id) - - -class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): + models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] + selections.install_models.extend(models) + + +class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore + def __init__(self, opt: Namespace, install_helper: InstallHelper): super().__init__() self.program_opts = opt self.user_cancelled = False - # self.autoload_pending = True self.install_selections = InstallSelections() + self.install_helper = install_helper - def onStart(self): + def onStart(self) -> None: npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -613,138 +499,62 @@ def onStart(self): ) -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): + """Print out all models of type model_type.""" + models = installer.record_store.search_by_attr(model_type=model_type) + print(f"Installed models of type `{model_type}`:") + for model in models: + path = (config.models_path / model.path).resolve() + 
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the scheduler prediction type of the checkpoint named {model_path.name}: -[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images -[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models -[3] Accept the best guess; you can fix it in the Web UI later -""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select [3]> ").strip() - if not choice: - return None - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "guess": - return None - else: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace): +def select_and_download_models(opt: Namespace) -> None: + """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal + config.precision = precision # type: ignore + install_helper = InstallHelper(config, logger) + installer = install_helper.installer + if opt.list_models: - installer.list_models(opt.list_models) + list_models(installer, opt.list_models) + elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) + 
selections = InstallSelections( + install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] + ) + install_helper.add_or_delete(selections) + elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) + default_model = install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) + install_helper.add_or_delete(selections) + elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) + selections = InstallSelections(install_models=install_helper.recommended_models()) + install_helper.add_or_delete(selections) # this is where the TUI is called else: - # needed to support the probe() method running under a subprocess - torch.multiprocessing.set_start_method("spawn") - if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - installApp = AddModelApplication(opt) + installApp = AddModelApplication(opt, install_helper) try: installApp.run() - except KeyboardInterrupt as e: - if hasattr(installApp, "main_form"): - if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): - logger.info("Terminating subprocesses") - installApp.main_form.subprocess.terminate() - installApp.main_form.subprocess = None - raise e - process_and_execute(opt, installApp.install_selections) + except KeyboardInterrupt: + print("Aborted...") + sys.exit(-1) + + install_helper.add_or_delete(installApp.install_selections) # ------------------------------------- -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -754,7 +564,7 @@ def main(): parser.add_argument( "--delete", nargs="*", - help="List of names of models to idelete", + help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", ) parser.add_argument( "--full-precision", @@ -781,14 +591,6 @@ def main(): choices=[x.value for x in ModelType], help="list installed models", ) - parser.add_argument( - "--config_file", - "-c", - dest="config_file", - type=str, - default=None, - help="path to configuration file to create", - ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install.py.OLD similarity index 57% rename from invokeai/frontend/install/model_install2.py rename to invokeai/frontend/install/model_install.py.OLD index 22b132370e6..e23538ffd66 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install.py.OLD @@ -6,45 +6,47 @@ """ This is the npyscreen frontend to the model installation application. -It is currently named model_install2.py, but will ultimately replace model_install.py. +The work is actually done in backend code in model_install_backend.py. 
""" import argparse import curses +import logging import sys +import textwrap import traceback -import warnings from argparse import Namespace +from multiprocessing import Process +from multiprocessing.connection import Connection, Pipe +from pathlib import Path from shutil import get_terminal_size -from typing import Any, Dict, List, Optional, Set +from typing import Optional import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo -from invokeai.backend.model_manager import ModelType +from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType +from invokeai.backend.model_management import ModelManager, ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, + BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, + select_stable_diffusion_config_file, set_min_terminal_size, ) -warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger("ModelInstallService") -logger.setLevel("WARNING") -# logger.setLevel('DEBUG') +logger = InvokeAILogger.get_logger() # build a table mapping all non-printable characters to None # for stripping control characters @@ -56,42 +58,44 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string.""" + """Replace non-printable characters in a string""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): - """Main form for interactive TUI.""" - # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): + def __init__(self, parentApp, name, multipage=False, *args, **keywords): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, **keywords) + super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? - def create(self) -> None: - self.installer = self.parentApp.install_helper.installer - self.model_labels = self._get_model_labels() + def create(self): self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None + if not config.model_conf_path.exists(): + with open(config.model_conf_path, "w") as file: + print("# InvokeAI model configuration file", file=file) + self.installer = ModelInstall(config) + self.all_models = self.installer.all_models() + self.starter_models = self.installer.starter_models() + self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - # npyscreen has no typing hints - self.nextrely -= 1 # type: ignore + self.nextrely -= 1 self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 # type: ignore + self.nextrely += 1 self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -111,9 +115,9 @@ def create(self) -> None: ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely # type: ignore + top_of_table = self.nextrely self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely # type: ignore + bottom_of_table = self.nextrely self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -158,7 +162,15 @@ def create(self) -> None: self.nextrely = bottom_of_table + 1 + self.monitor = self.add_widget_intelligent( + BufferBox, + name="Log Messages", + editable=False, + max_height=6, + ) + self.nextrely += 1 + done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -174,8 +186,14 @@ def create(self) -> None: npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position + self.ok_button = self.add_widget_intelligent( + npyscreen.ButtonPress, + name=done_label, + relx=(window_width - len(done_label)) // 2, + when_pressed_function=self.on_execute, + ) - label = "APPLY CHANGES" + label = "APPLY CHANGES & EXIT" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -192,16 +210,17 @@ def create(self) -> None: ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets: Dict[str, npyscreen.widget] = {} + widgets = {} + models = self.all_models + starters = self.starter_models + starter_model_labels = self.model_labels - all_models = self.all_models # master dict of all models, indexed by key - model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] - model_labels = [self.model_labels[x] for x in model_list] + self.installed_models = sorted([x for x in starters if models[x].installed]) widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", + name="Select from a starter set of Stable Diffusion models from HuggingFace.", editable=False, labelColor="CAUTION", ) @@ -211,24 +230,23 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - - checked = [ - model_list.index(x) - for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed - ] + keys = [x for x in models.keys() if x in starters] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=model_labels, - value=checked, - max_height=len(model_list) + 1, + values=[starter_model_labels[x] for x in keys], + value=[ + keys.index(x) + for x in keys + if (show_recommended and models[x].recommended) or (x in self.installed_models) + ], + max_height=len(starters) + 1, relx=4, scroll_exit=True, ), - models=model_list, + models=keys, ) self.nextrely += 1 @@ -239,18 +257,14 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: Optional[str] = None, - exclude: Optional[Set[str]] = None, + install_prompt: str = None, + exclude: 
set = None, ) -> dict[str, npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets: Dict[str, npyscreen.widget] = {} - all_models = self.all_models - model_list = sorted( - [x for x in all_models if all_models[x].type == model_type and x not in exclude], - key=lambda x: all_models[x].name or "", - ) + widgets = {} + model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -286,7 +300,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed + if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -310,7 +324,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=6, + max_height=4, scroll_exit=True, editable=True, ) @@ -335,13 +349,13 @@ def add_pipeline_widgets( return widgets - def resize(self) -> None: + def resize(self): super().resize() if s := self.starter_pipelines.get("models_selected"): - if model_list := self.starter_pipelines.get("models"): - s.values = [self.model_labels[x] for x in model_list] + keys = [x for x in self.all_models.keys() if x in self.starter_models] + s.values = [self.model_labels[x] for x in keys] - def _toggle_tables(self, value: List[int]) -> None: + def _toggle_tables(self, value=None): selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -371,18 +385,17 @@ def _toggle_tables(self, value: List[int]) -> None: self.display() def _get_model_labels(self) -> dict[str, str]: - """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 - result = {} models = self.all_models - label_width = max([len(models[x].name or "") for x in self.starter_models]) + label_width = max([len(models[x].name) for x in models]) description_width = window_width - label_width - checkbox_width - spacing_width - for key in self.all_models: - description = models[key].description + result = {} + for x in models.keys(): + description = models[x].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -390,8 +403,7 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[key] = f"%-{label_width}s %s" % (models[key].name, description) - + result[x] = f"%-{label_width}s %s" % (models[x].name, description) return result def _get_columns(self) -> int: @@ -401,40 +413,50 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if remove_models: - model_names = [self.all_models[x].name or "" for x in remove_models] - mods = "\n".join(model_names) - is_ok = npyscreen.notify_ok_cancel( + if len(remove_models) > 0: + mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) + return npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) - assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations - return is_ok else: return True - @property - def all_models(self) -> Dict[str, UnifiedModelInfo]: - # npyscreen doesn't having typing hints - return self.parentApp.install_helper.all_models # type: ignore + def on_execute(self): + self.marshall_arguments() + app = self.parentApp + if not self.confirm_deletions(app.install_selections): + return - @property - def starter_models(self) -> List[str]: - return self.parentApp.install_helper._starter_models # type: ignore + self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) + self.ok_button.hidden = True + self.display() - @property - def installed_models(self) -> List[str]: - return self.parentApp.install_helper._installed_models # type: ignore + # TO DO: Spawn a worker thread, not a subprocess + parent_conn, child_conn = Pipe() + p = Process( + target=process_and_execute, + kwargs={ + "opt": app.program_opts, + "selections": app.install_selections, + "conn_out": child_conn, + }, + ) + p.start() + child_conn.close() + self.subprocess_connection = parent_conn + self.subprocess = p + app.install_selections = InstallSelections() - def on_back(self) -> None: + def on_back(self): self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self) -> None: + def on_cancel(self): self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self) -> None: + def on_done(self): self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -442,7 +464,77 @@ def on_done(self) -> None: self.parentApp.user_cancelled = False self.editing = False - def marshall_arguments(self) -> None: + ########## This routine monitors the child process that is performing model installation and removal ##### + def while_waiting(self): + """Called during idle periods. Main task is to update the Log Messages box with messages + from the child process that does the actual installation/removal""" + c = self.subprocess_connection + if not c: + return + + monitor_widget = self.monitor.entry_widget + while c.poll(): + try: + data = c.recv_bytes().decode("utf-8") + data.strip("\n") + + # processing child is requesting user input to select the + # right configuration file + if data.startswith("*need v2 config"): + _, model_path, *_ = data.split(":", 2) + self._return_v2_config(model_path) + + # processing child is done + elif data == "*done*": + self._close_subprocess_and_regenerate_form() + break + + # update the log message box + else: + data = make_printable(data) + data = data.replace("[A", "") + monitor_widget.buffer( + textwrap.wrap( + data, + width=monitor_widget.width, + subsequent_indent=" ", + ), + scroll_end=True, + ) + self.display() + except (EOFError, OSError): + self.subprocess_connection = None + + def _return_v2_config(self, model_path: str): + c = self.subprocess_connection + model_name = Path(model_path).name + message = select_stable_diffusion_config_file(model_name=model_name) + c.send_bytes(message.encode("utf-8")) + + def _close_subprocess_and_regenerate_form(self): + app = self.parentApp + self.subprocess_connection.close() + self.subprocess_connection = None + self.monitor.entry_widget.buffer(["** Action Complete **"]) + self.display() + + # rebuild the form, saving and restoring some of the fields that need to be preserved. 
+ saved_messages = self.monitor.entry_widget.values + + app.main_form = app.addForm( + "MAIN", + addModelsForm, + name="Install Stable Diffusion Models", + multipage=self.multipage, + ) + app.switchForm("MAIN") + + app.main_form.monitor.entry_widget.values = saved_messages + app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) + # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir + # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan + + def marshall_arguments(self): """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -472,24 +564,46 @@ def marshall_arguments(self) -> None: models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend([all_models[x] for x in models_to_install]) + selections.install_models.extend( + all_models[x].path or all_models[x].repo_id + for x in models_to_install + if all_models[x].path or all_models[x].repo_id + ) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] - selections.install_models.extend(models) - - -class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, opt: Namespace, install_helper: InstallHelper): + selections.install_models.extend(downloads.value.split()) + + # NOT NEEDED - DONE IN BACKEND NOW + # # special case for the ipadapter_models. If any of the adapters are + # # chosen, then we add the corresponding encoder(s) to the install list. 
+ # section = self.ipadapter_models + # if section.get("models_selected"): + # selected_adapters = [ + # self.all_models[section["models"][x]].name for x in section.get("models_selected").value + # ] + # encoders = [] + # if any(["sdxl" in x for x in selected_adapters]): + # encoders.append("ip_adapter_sdxl_image_encoder") + # if any(["sd15" in x for x in selected_adapters]): + # encoders.append("ip_adapter_sd_image_encoder") + # for encoder in encoders: + # key = f"any/clip_vision/{encoder}" + # repo_id = f"InvokeAI/{encoder}" + # if key not in self.all_models: + # selections.install_models.append(repo_id) + + +class AddModelApplication(npyscreen.NPSAppManaged): + def __init__(self, opt): super().__init__() self.program_opts = opt self.user_cancelled = False + # self.autoload_pending = True self.install_selections = InstallSelections() - self.install_helper = install_helper - def onStart(self) -> None: + def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -499,62 +613,138 @@ def onStart(self) -> None: ) -def list_models(installer: ModelInstallServiceBase, model_type: ModelType): - """Print out all models of type model_type.""" - models = installer.record_store.search_by_attr(model_type=model_type) - print(f"Installed models of type `{model_type}`:") - for model in models: - path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") +class StderrToMessage: + def __init__(self, connection: Connection): + self.connection = connection + + def write(self, data: str): + self.connection.send_bytes(data.encode("utf-8")) + + def flush(self): + pass # -------------------------------------------------------- -def select_and_download_models(opt: Namespace) -> None: - """Prompt user for install/delete selections and execute.""" - precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal - config.precision = precision # type: ignore - install_helper = InstallHelper(config, logger) - installer = install_helper.installer +def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: + if tui_conn: + logger.debug("Waiting for user response...") + return _ask_user_for_pt_tui(model_path, tui_conn) + else: + return _ask_user_for_pt_cmdline(model_path) - if opt.list_models: - list_models(installer, opt.list_models) - elif opt.add or opt.delete: - selections = InstallSelections( - install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] - ) - install_helper.add_or_delete(selections) +def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: + choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] + print( + f""" +Please select the scheduler prediction type of the checkpoint named {model_path.name}: +[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images +[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models +[3] Accept the best guess; you can fix it in the Web UI later +""" + ) + choice = None + ok = False + while not ok: + try: + choice = input("select [3]> ").strip() + if not choice: + return None + choice = choices[int(choice) - 1] + ok = True + except (ValueError, IndexError): + print(f"{choice} is not a valid choice") + 
except EOFError: + return + return choice + + +def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: + tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) + # note that we don't do any status checking here + response = tui_conn.recv_bytes().decode("utf-8") + if response is None: + return None + elif response == "epsilon": + return SchedulerPredictionType.epsilon + elif response == "v": + return SchedulerPredictionType.VPrediction + elif response == "guess": + return None + else: + return None - elif opt.default_only: - default_model = install_helper.default_model() - assert default_model is not None - selections = InstallSelections(install_models=[default_model]) - install_helper.add_or_delete(selections) +# -------------------------------------------------------- +def process_and_execute( + opt: Namespace, + selections: InstallSelections, + conn_out: Connection = None, +): + # need to reinitialize config in subprocess + config = InvokeAIAppConfig.get_config() + args = ["--root", opt.root] if opt.root else [] + config.parse_args(args) + + # set up so that stderr is sent to conn_out + if conn_out: + translator = StderrToMessage(conn_out) + sys.stderr = translator + sys.stdout = translator + logger = InvokeAILogger.get_logger() + logger.handlers.clear() + logger.addHandler(logging.StreamHandler(translator)) + + installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) + installer.install(selections) + + if conn_out: + conn_out.send_bytes("*done*".encode("utf-8")) + conn_out.close() + + +# -------------------------------------------------------- +def select_and_download_models(opt: Namespace): + precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) + config.precision = precision + installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + if opt.list_models: + installer.list_models(opt.list_models) + elif opt.add or opt.delete: + selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) + installer.install(selections) + elif opt.default_only: + selections = InstallSelections(install_models=installer.default_model()) + installer.install(selections) elif opt.yes_to_all: - selections = InstallSelections(install_models=install_helper.recommended_models()) - install_helper.add_or_delete(selections) + selections = InstallSelections(install_models=installer.recommended_models()) + installer.install(selections) # this is where the TUI is called else: + # needed to support the probe() method running under a subprocess + torch.multiprocessing.set_start_method("spawn") + if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." 
) - installApp = AddModelApplication(opt, install_helper) + installApp = AddModelApplication(opt) try: installApp.run() - except KeyboardInterrupt: - print("Aborted...") - sys.exit(-1) - - install_helper.add_or_delete(installApp.install_selections) + except KeyboardInterrupt as e: + if hasattr(installApp, "main_form"): + if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): + logger.info("Terminating subprocesses") + installApp.main_form.subprocess.terminate() + installApp.main_form.subprocess = None + raise e + process_and_execute(opt, installApp.install_selections) # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -564,7 +754,7 @@ def main() -> None: parser.add_argument( "--delete", nargs="*", - help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", + help="List of names of models to idelete", ) parser.add_argument( "--full-precision", @@ -591,6 +781,14 @@ def main() -> None: choices=[x.value for x in ModelType], help="list installed models", ) + parser.add_argument( + "--config_file", + "-c", + dest="config_file", + type=str, + default=None, + help="path to configuration file to create", + ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 5905ae29dab..4dbc6349a0b 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -267,6 +267,17 @@ def h_select(self, ch): self.on_changed(self.value) +class CheckboxWithChanged(npyscreen.Checkbox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.on_changed = None + + def whenToggled(self): + super().whenToggled() + if self.on_changed: + self.on_changed(self.value) + + class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged): """Row of radio buttons. 
Spacebar to select.""" diff --git a/invokeai/frontend/merge/merge_diffusers2.py b/invokeai/frontend/merge/merge_diffusers.py.OLD similarity index 100% rename from invokeai/frontend/merge/merge_diffusers2.py rename to invokeai/frontend/merge/merge_diffusers.py.OLD diff --git a/pyproject.toml b/pyproject.toml index 8b28375e291..2958e3629a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,8 +136,7 @@ dependencies = [ # full commands "invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure" -"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers" -"invokeai-merge2" = "invokeai.frontend.merge.merge_diffusers2:main" +"invokeai-merge" = "invokeai.frontend.merge.merge_diffusers:main" "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion" "invokeai-model-install" = "invokeai.frontend.install.model_install:main" "invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py deleted file mode 100644 index 3e48c7ed6fc..00000000000 --- a/tests/test_model_manager.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path - -import pytest - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType - -BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main) - - -@pytest.fixture -def model_manager(datadir) -> ModelManager: - InvokeAIAppConfig.get_config(root=datadir) - return ModelManager(datadir / "configs" / "relative_sub.models.yaml") - - -def test_get_model_names(model_manager: ModelManager): - names = model_manager.model_names() - assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] - - -def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) - top_model_path, is_override = model_manager._get_model_path(model_config) - expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" - assert top_model_path == expected_model_path - assert not is_override - - -def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" - assert vae_model_path == expected_vae_path - assert is_override - - -def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - assert not is_override From f7e558d16555331e6f4862cea4117460caef9275 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Feb 2024 20:46:47 -0500 Subject: [PATCH 080/100] probe for required encoder for IPAdapters and add to config --- 
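The probe-side change is easiest to follow in isolation. Below is a minimal sketch of the lookup this patch moves from invocation time into probe.py, assuming an IP-Adapter model folder that carries an image_encoder.txt file holding the HuggingFace repo id of its CLIP Vision encoder (e.g. InvokeAI/ip_adapter_sd_image_encoder); the helper name read_image_encoder_id is illustrative and not part of the patch:

from pathlib import Path
from typing import Optional

def read_image_encoder_id(model_path: Path) -> Optional[str]:
    """Return the image encoder repo id recorded alongside an IP-Adapter, or None."""
    encoder_id_path = model_path / "image_encoder.txt"
    if not encoder_id_path.exists():
        return None
    # The file holds a single line such as "InvokeAI/ip_adapter_sd_image_encoder".
    with open(encoder_id_path, "r") as f:
        return f.readline().strip()

Because the repo id is now captured once at probe time and persisted in IPAdapterConfig.image_encoder_model_id, the invocation no longer re-reads the file from disk on every call, which is the problem the removed HACK comment below describes.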
invokeai/app/invocations/ip_adapter.py | 24 +----------------------- invokeai/backend/model_manager/config.py | 1 + invokeai/backend/model_manager/probe.py | 13 +++++++++++++ 3 files changed, 15 insertions(+), 23 deletions(-) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 700b285a45f..f64b3266bbb 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -1,4 +1,3 @@ -import os from builtins import float from typing import List, Union @@ -52,16 +51,6 @@ def validate_begin_end_step_percent(self) -> Self: return self -def get_ip_adapter_image_encoder_model_id(model_path: str): - """Read the ID of the image encoder associated with the IP-Adapter at `model_path`.""" - image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") - - with open(image_encoder_config_file, "r") as f: - image_encoder_model = f.readline().strip() - - return image_encoder_model - - @invocation_output("ip_adapter_output") class IPAdapterOutput(BaseInvocationOutput): # Outputs @@ -102,18 +91,7 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) - # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model - # directly, and 2) we are reading from disk every time this invocation is called without caching the result. - # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this - # is currently messy due to differences between how the model info is generated when installing a model from - # disk vs. downloading the model. - # TODO (LS): Fix the issue above by: - # 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model. - # 2. Update probe.py to read `image_encoder.txt` and store it in the config. - # 3. Change below to get the image encoder from the configuration record. 
- image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path) - ) + image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_models = context.services.model_records.search_by_attr( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 0dcd925c84b..d2e7a0923a4 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -263,6 +263,7 @@ class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter + image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 55a9c0464a5..e7d21c578fd 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -78,6 +78,10 @@ def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Get model scheduler prediction type.""" return None + def get_image_encoder_model_id(self) -> Optional[str]: + """Get image encoder (IP adapters only).""" + return None + class ModelProbe(object): PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { @@ -153,6 +157,7 @@ def probe( fields["base"] = fields.get("base") or probe.get_base_type() fields["variant"] = fields.get("variant") or probe.get_variant_type() fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() + fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() fields["name"] = fields.get("name") or cls.get_model_name(model_path) fields["description"] = ( fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" @@ -669,6 +674,14 @@ def get_base_type(self) -> BaseModelType: f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." 
) + def get_image_encoder_model_id(self) -> Optional[str]: + encoder_id_path = self.model_path / "image_encoder.txt" + if not encoder_id_path.exists(): + return None + with open(encoder_id_path, "r") as f: + image_encoder_model = f.readline().strip() + return image_encoder_model + class CLIPVisionFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: From 2b1dc74080a0bf77891fd0bed4ad9b984b950dd0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Feb 2024 23:08:38 -0500 Subject: [PATCH 081/100] consolidate model manager parts into a single class --- invokeai/app/services/model_load/__init__.py | 6 + .../services/model_load/model_load_base.py | 22 + .../services/model_load/model_load_default.py | 54 +++ .../app/services/model_manager/__init__.py | 17 +- .../model_manager/model_manager_base.py | 294 +---------- .../model_manager/model_manager_default.py | 456 ++---------------- invokeai/backend/__init__.py | 9 - invokeai/backend/model_manager/config.py | 6 +- .../backend/model_manager/load/__init__.py | 2 +- invokeai/backend/model_manager/search.py | 12 +- 10 files changed, 184 insertions(+), 694 deletions(-) create mode 100644 invokeai/app/services/model_load/__init__.py create mode 100644 invokeai/app/services/model_load/model_load_base.py create mode 100644 invokeai/app/services/model_load/model_load_default.py diff --git a/invokeai/app/services/model_load/__init__.py b/invokeai/app/services/model_load/__init__.py new file mode 100644 index 00000000000..b4a86e9348d --- /dev/null +++ b/invokeai/app/services/model_load/__init__.py @@ -0,0 +1,6 @@ +"""Initialization file for model load service module.""" + +from .model_load_base import ModelLoadServiceBase +from .model_load_default import ModelLoadService + +__all__ = ["ModelLoadServiceBase", "ModelLoadService"] diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py new file mode 100644 index 00000000000..7228806e809 --- /dev/null +++ b/invokeai/app/services/model_load/model_load_base.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team +"""Base class for model loader.""" + +from abc import ABC, abstractmethod +from typing import Optional + +from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import LoadedModel + + +class ModelLoadServiceBase(ABC): + """Wrapper around AnyModelLoader.""" + + @abstractmethod + def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's key, load it and return the LoadedModel object.""" + pass + + @abstractmethod + def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's configuration, load it and return the LoadedModel object.""" + pass diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py new file mode 100644 index 00000000000..80e2fe161d0 --- /dev/null +++ b/invokeai/app/services/model_load/model_load_default.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Lincoln D. 
Stein and the InvokeAI Team +"""Implementation of model loader service.""" + +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase +from invokeai.backend.util.logging import InvokeAILogger + +from .model_load_base import ModelLoadServiceBase + + +class ModelLoadService(ModelLoadServiceBase): + """Wrapper around AnyModelLoader.""" + + def __init__( + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + ram_cache: Optional[ModelCacheBase] = None, + convert_cache: Optional[ModelConvertCacheBase] = None, + ): + """Initialize the model load service.""" + logger = InvokeAILogger.get_logger(self.__class__.__name__) + logger.setLevel(app_config.log_level.upper()) + self._store = record_store + self._any_loader = AnyModelLoader( + app_config=app_config, + logger=logger, + ram_cache=ram_cache + or ModelCache( + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size, + logger=logger, + ), + convert_cache=convert_cache + or ModelConvertCache( + cache_path=app_config.models_convert_cache_path, + max_size=app_config.convert_cache_size, + ), + ) + + def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's key, load it and return the LoadedModel object.""" + config = self._store.get_model(key) + return self.load_model_by_config(config, submodel_type) + + def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's configuration, load it and return the LoadedModel object.""" + return self._any_loader.load_model(config, submodel_type) diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 3d6a9c248c6..5e281922a8b 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -1 +1,16 @@ -from .model_manager_default import ModelManagerService # noqa F401 +"""Initialization file for model manager service.""" + +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load import LoadedModel + +from .model_manager_default import ModelManagerService + +__all__ = [ + "ModelManagerService", + "AnyModel", + "AnyModelConfig", + "BaseModelType", + "ModelType", + "SubModelType", + "LoadedModel", +] diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index f888c0ec973..c6e77fa163d 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,283 +1,39 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team -from __future__ import annotations - from abc import ABC, abstractmethod -from logging import Logger -from pathlib import Path -from typing import Callable, List, Literal, Optional, Tuple, Union - -from pydantic import Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - LoadedModelInfo, - MergeInterpolationMethod, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats - - -class ModelManagerServiceBase(ABC): - """Responsible for managing models on disk and in memory""" - - @abstractmethod - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - pass - @abstractmethod - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) - of a diffusers pipeline.""" - pass - - @property - @abstractmethod - def logger(self): - pass - - @abstractmethod - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - pass - - @abstractmethod - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - Uses the exact format as the omegaconf stanza. - """ - pass +from pydantic import BaseModel, Field +from typing_extensions import Self - @abstractmethod - def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict: - """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } - """ - pass +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallServiceBase +from ..model_load import ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase +from ..shared.sqlite.sqlite_database import SqliteDatabase - @abstractmethod - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Return information about the model using the same format as list_models() - """ - pass - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. 
- """ - pass +class ModelManagerServiceBase(BaseModel, ABC): + """Abstract base class for the model manager service.""" - @abstractmethod - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass + store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") + install: ModelInstallServiceBase = Field(description="An instance of the model install service.") + load: ModelLoadServiceBase = Field(description="An instance of the model load service.") + @classmethod @abstractmethod - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException if the name does not already exist. + Construct the model manager service instance. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. - """ - pass - - @abstractmethod - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): - """ - Rename the indicated model. - """ - pass - - @abstractmethod - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - pass - - @abstractmethod - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - pass - - @abstractmethod - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. 
- :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - """ - pass - - @abstractmethod - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - pass - - @abstractmethod - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - pass - - @abstractmethod - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - pass - - @abstractmethod - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ - pass - - @abstractmethod - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. + Use it rather than the __init__ constructor. This class + method simplifies the construction considerably. """ pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index c3712abf8e6..ad0fd66dbbd 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,421 +1,67 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team +"""Implementation of ModelManagerServiceBase.""" -from __future__ import annotations +from typing_extensions import Self -from logging import Logger -from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union - -import torch -from pydantic import Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - LoadedModelInfo, - MergeInterpolationMethod, - ModelManager, - ModelMerger, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats -from invokeai.backend.model_management.model_search import FindModels -from invokeai.backend.util import choose_precision, choose_torch_device +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache +from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.util.logging import InvokeAILogger +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallService +from ..model_load import ModelLoadService +from ..model_records import ModelRecordServiceSQL +from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_manager_base import ModelManagerServiceBase -if TYPE_CHECKING: - pass - -# simple implementation class ModelManagerService(ModelManagerServiceBase): - """Responsible for managing models on disk and in memory""" - - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - if config.model_conf_path and config.model_conf_path.exists(): - config_file = config.model_conf_path - else: - config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f"Config file={config_file}") - - device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" - logger.info(f"GPU device = {device} {device_name}") - - precision = config.precision - if precision == "auto": - precision = choose_precision(device) - dtype = torch.float32 if precision == "float32" else torch.float16 - - # this is transitional backward compatibility - # support for the deprecated `max_loaded_models` - # configuration value. If present, then the - # cache size is set to 2.5 GB times - # the number of max_loaded_models. Otherwise - # use new `ram_cache_size` config setting - max_cache_size = config.ram_cache_size + """ + The ModelManagerService handles various aspects of model installation, maintenance and loading. - logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + It bundles three distinct services: + model_manager.store -- Routines to manage the database of model configuration records. + model_manager.install -- Routines to install, move and delete models. 
+ model_manager.load -- Routines to load models into memory. + """ - sequential_offload = config.sequential_guidance - - self.mgr = ModelManager( - config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger, - ) - logger.info("Model manager service initialized") - - def start(self, invoker: Invoker) -> None: - self._invoker: Optional[Invoker] = invoker - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModelInfo: - """ - Retrieve the indicated model. submodel can be used to get a - part (such as the vae) of a diffusers mode. + @classmethod + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ + Construct the model manager service instance. - # we can emit model loading events if we are executing with access to the invocation context - if context_data is not None: - self._emit_load_event( - context_data=context_data, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) + For simplicity, use this class method rather than the __init__ constructor. + """ + logger = InvokeAILogger.get_logger(cls.__name__) + logger.setLevel(app_config.log_level.upper()) - loaded_model_info = self.mgr.get_model( - model_name, - base_model, - model_type, - submodel, + ram_cache = ModelCache( + max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger ) - - if context_data is not None: - self._emit_load_event( - context_data=context_data, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - loaded_model_info=loaded_model_info, - ) - - return loaded_model_info - - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - """ - Given a model name, returns True if it is a valid - identifier. - """ - return self.mgr.model_exists( - model_name, - base_model, - model_type, + convert_cache = ModelConvertCache( + cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size ) - - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - """ - return self.mgr.model_info(model_name, base_model, model_type) - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - return self.mgr.model_names() - - def list_models( - self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> list[dict]: - """ - Return a list of models. 
- """ - return self.mgr.list_models(base_model, model_type) - - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Return information about the model using the same format as list_models() - """ - return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) - - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"add/update model {model_name}") - return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException exception if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"update model {model_name}") - if not self.model_exists(model_name, base_model, model_type): - raise ModelNotFoundException(f"Unknown model {model_name}") - return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) - - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. - """ - self.logger.debug(f"delete model {model_name}") - self.mgr.del_model(model_name, base_model, model_type) - self.mgr.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - convert_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - self.logger.debug(f"convert model {model_name}") - return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) - - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. 
- """ - self.mgr.cache.stats = cache_stats - - def commit(self, conf_file: Optional[Path] = None): - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ - return self.mgr.commit(conf_file) - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - loaded_model_info: Optional[LoadedModelInfo] = None, - ): - if self._invoker is None: - return - - if self._invoker.services.queue.is_canceled(context_data.session_id): - raise CanceledException() - - if loaded_model_info: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - loaded_model_info=loaded_model_info, - ) - else: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - @property - def logger(self): - return self.mgr.logger - - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - """ - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) - - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. 
- :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - merger = ModelMerger(self.mgr) - try: - result = merger.merge_diffusion_models_and_save( - model_names=model_names, - base_model=base_model, - merged_model_name=merged_model_name, - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=merge_dest_directory, - ) - except AssertionError as e: - raise ValueError(e) - return result - - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - search = FindModels([directory], self.logger) - return search.list_models() - - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - return self.mgr.sync_to_config() - - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - config = self.mgr.app_config - conf_path = config.legacy_conf_path - root_path = config.root_path - return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ): - """ - Rename the indicated model. Can provide a new name and/or a new base. 
-        :param model_name: Current name of the model
-        :param base_model: Current base of the model
-        :param model_type: Model type (can't be changed)
-        :param new_name: New name for the model
-        :param new_base: New base for the model
-        """
-        self.mgr.rename_model(
-            base_model=base_model,
-            model_type=model_type,
-            model_name=model_name,
-            new_name=new_name,
-            new_base=new_base,
+        record_store = ModelRecordServiceSQL(db=db)
+        loader = ModelLoadService(
+            app_config=app_config,
+            record_store=record_store,
+            ram_cache=ram_cache,
+            convert_cache=convert_cache,
+        )
+        record_store._loader = loader  # yeah, there is a circular reference here
+        installer = ModelInstallService(
+            app_config=app_config,
+            record_store=record_store,
+            download_queue=download_queue,
+            metadata_store=ModelMetadataStore(db=db),
+            event_bus=events,
         )
+        return cls(store=record_store, install=installer, load=loader)
diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py
index 54a1843d463..9fe97ee525e 100644
--- a/invokeai/backend/__init__.py
+++ b/invokeai/backend/__init__.py
@@ -1,12 +1,3 @@
 """
 Initialization file for invokeai.backend
 """
-from .model_management import (  # noqa: F401
-    BaseModelType,
-    LoadedModelInfo,
-    ModelCache,
-    ModelManager,
-    ModelType,
-    SubModelType,
-)
-from .model_management.models import SilenceWarnings  # noqa: F401
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index d2e7a0923a4..4534a4892fb 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -21,7 +21,7 @@
 """
 import time
 from enum import Enum
-from typing import Literal, Optional, Type, Union
+from typing import Any, Dict, Literal, Optional, Type, Union
 
 import torch
 from diffusers import ModelMixin
@@ -333,9 +333,9 @@ class ModelConfigFactory(object):
     @classmethod
     def make_config(
         cls,
-        model_data: Union[dict, AnyModelConfig],
+        model_data: Union[Dict[str, Any], AnyModelConfig],
         key: Optional[str] = None,
-        dest_class: Optional[Type] = None,
+        dest_class: Optional[Type[Any]] = None,
         timestamp: Optional[float] = None,
     ) -> AnyModelConfig:
         """
diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py
index e4c7077f783..966a739237a 100644
--- a/invokeai/backend/model_manager/load/__init__.py
+++ b/invokeai/backend/model_manager/load/__init__.py
@@ -18,7 +18,7 @@
 for module in loaders:
     import_module(f"{__package__}.model_loaders.{module}")
 
-__all__ = ["AnyModelLoader", "LoadedModel"]
+__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
 
 
 def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py
index 4cc3caebe47..a54938fdd5c 100644
--- a/invokeai/backend/model_manager/search.py
+++ b/invokeai/backend/model_manager/search.py
@@ -26,10 +26,10 @@ def find_main_models(model: Path) -> bool:
 from typing import Callable, Optional, Set, Union
 
 from pydantic import BaseModel, Field
-
+from logging import Logger
 from invokeai.backend.util.logging import InvokeAILogger
 
-default_logger = InvokeAILogger.get_logger()
+default_logger: Logger = InvokeAILogger.get_logger()
 
 
 class SearchStats(BaseModel):
@@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
     on_model_found      : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.")  # noqa E221
     on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.")  # noqa E221
     stats               : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search")  # noqa E221
-    logger              : InvokeAILogger = Field(default=default_logger, description="Logger instance.")  # noqa E221
+    logger              : Logger = Field(default=default_logger, description="Logger instance.")  # noqa E221
     # fmt: on
 
     class Config:
@@ -128,13 +128,13 @@ def search_started(self) -> None:
 
     def model_found(self, model: Path) -> None:
         self.stats.models_found += 1
-        if not self.on_model_found or self.on_model_found(model):
+        if self.on_model_found is None or self.on_model_found(model):
             self.stats.models_filtered += 1
             self.models_found.add(model)
 
     def search_completed(self) -> None:
-        if self.on_search_completed:
-            self.on_search_completed(self._models_found)
+        if self.on_search_completed is not None:
+            self.on_search_completed(self.models_found)
 
     def search(self, directory: Union[Path, str]) -> Set[Path]:
         self._directory = Path(directory)

From 94e8d1b6d57dfc860c99a7bf37c260a11af28045 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 10 Feb 2024 18:09:45 -0500
Subject: [PATCH 082/100] make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.
- Update invocations to use the new load interface.
- Fix many (but not all) type-checking errors in the invocations; most
  were unrelated to the model manager.
- Update routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes, they have
  the URL prefix `/api/v2/models`. The old routes have been
  de-registered.
- Add a pytest for the loader.
- Update documentation in contributing/MODEL_MANAGER.md.
---
 docs/contributing/MODEL_MANAGER.md            | 223 ++++++++++++------
 invokeai/app/api/dependencies.py              |  29 +--
 .../{model_records.py => model_manager_v2.py} |  82 ++++---
 invokeai/app/api/routers/models.py            |   3 +-
 invokeai/app/api_app.py                       |   5 +-
 invokeai/app/invocations/compel.py            |  43 ++--
 invokeai/app/invocations/latent.py            | 168 ++++++++-----
 invokeai/app/invocations/model.py             |   2 +-
 invokeai/app/services/invocation_services.py  |   6 -
 .../invocation_stats_default.py               |   9 +-
 .../services/model_load/model_load_base.py    |  60 ++++-
 .../services/model_load/model_load_default.py | 115 ++++++++-
 .../model_manager/model_manager_base.py       |  38 ++-
 .../model_manager/model_manager_default.py    |  39 ++-
 .../model_records/model_records_base.py       |  49 ----
 .../model_records/model_records_sql.py        |  99 +-------
 .../sqlite_migrator/migrations/migration_6.py |  19 ++
 invokeai/backend/embeddings/model_patcher.py  |   4 +-
 invokeai/backend/image_util/safety_checker.py |   2 +-
 invokeai/backend/ip_adapter/ip_adapter.py     |   4 +-
 invokeai/backend/model_manager/config.py      |  13 +-
 .../backend/model_manager/load/load_base.py   |  18 +-
 .../model_manager/load/load_default.py        |   2 +-
 .../load/model_cache/model_cache_default.py   |   2 +-
 .../load/model_loaders/controlnet.py          |  12 +-
 .../load/model_loaders/generic_diffusers.py   |   5 +-
 .../load/model_loaders/stable_diffusion.py    |   8 +-
 .../model_manager/load/model_loaders/vae.py   |   8 +-
 .../backend/model_manager/load/model_util.py  |   2 +-
 invokeai/backend/model_manager/search.py      |   3 +-
 .../stable_diffusion/schedulers/__init__.py   |   2 +
 invokeai/frontend/install/model_install.py    |   2 +-
 tests/aa_nodes/test_graph_execution_state.py  |   2 -
 tests/aa_nodes/test_invoker.py                |   2 -
 .../model_loading/test_model_load.py          |  22 ++
 .../model_manager_2_fixtures.py               |  11 +
 36 files changed, 679 insertions(+), 434 deletions(-)
 rename
invokeai/app/api/routers/{model_records.py => model_manager_v2.py} (86%) create mode 100644 tests/backend/model_manager_2/model_loading/test_model_load.py diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 880c8b24801..39220f4ba89 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -28,7 +28,7 @@ model. These are the: Hugging Face, as well as discriminating among model versions in Civitai, but can be used for arbitrary content. - * _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**) + * _ModelLoadServiceBase_ Responsible for loading a model from disk into RAM and VRAM and getting it ready for inference. @@ -41,10 +41,10 @@ The four main services can be found in * `invokeai/app/services/model_records/` * `invokeai/app/services/model_install/` * `invokeai/app/services/downloads/` -* `invokeai/app/services/model_loader/` (**under development**) +* `invokeai/app/services/model_load/` Code related to the FastAPI web API can be found in -`invokeai/app/api/routers/model_records.py`. +`invokeai/app/api/routers/model_manager_v2.py`. *** @@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but `ModelType`, `ModelFormat` and `BaseModelType` are string enums that are defined in `invokeai.backend.model_manager.config`. They are also imported by, and can be reexported from, -`invokeai.app.services.model_record_service`: +`invokeai.app.services.model_manager.model_records`: ``` -from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType +from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType ``` The `path` field can be absolute or relative. If relative, it is taken @@ -123,7 +123,7 @@ taken to be the `models_dir` directory. `variant` is an enumerated string class with values `normal`, `inpaint` and `depth`. If needed, it can be imported if needed from -either `invokeai.app.services.model_record_service` or +either `invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### ONNXSD2Config @@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or | `upcast_attention` | bool | Model requires its attention module to be upcast | The `SchedulerPredictionType` enum can be imported from either -`invokeai.app.services.model_record_service` or +`invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### Other config classes @@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base models. This works OK for some models, such as the IP Adapter image encoders, but is an all-or-nothing proposition. -Another issue is that the config class hierarchy is paralleled to some -extent by a `ModelBase` class hierarchy defined in -`invokeai.backend.model_manager.models.base` and its subclasses. These -are classes representing the models after they are loaded into RAM and -include runtime information such as load status and bytes used. Some -of the fields, including `name`, `model_type` and `base_model`, are -shared between `ModelConfigBase` and `ModelBase`, and this is a -potential source of confusion. 
- ## Reading and Writing Model Configuration Records The `ModelRecordService` provides the ability to retrieve model @@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the `InvocationContext` object: ``` -store = context.services.model_record_store +store = context.services.model_manager.store ``` or from elsewhere in the code by accessing -`ApiDependencies.invoker.services.model_record_store`. +`ApiDependencies.invoker.services.model_manager.store`. ### Creating a `ModelRecordService` @@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a `ModelRecordServiceFile` object: ``` -from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile +from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile store = ModelRecordServiceSQL.from_connection(connection, lock) store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db') @@ -252,7 +243,7 @@ So a typical startup pattern would be: ``` import sqlite3 from invokeai.app.services.thread import lock -from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() @@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False) store = ModelRecordServiceBase.open(config, db_conn, lock) ``` -_A note on simultaneous access to `invokeai.db`_: The current InvokeAI -service architecture for the image and graph databases is careful to -use a shared sqlite3 connection and a thread lock to ensure that two -threads don't attempt to access the database simultaneously. However, -the default `sqlite3` library used by Python reports using -**Serialized** mode, which allows multiple threads to access the -database simultaneously using multiple database connections (see -https://www.sqlite.org/threadsafe.html and -https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore -it should be safe to allow the record service to open its own SQLite -database connection. Opening a model record service should then be as -simple as `ModelRecordServiceBase.open(config)`. - ### Fetching a Model's Configuration from `ModelRecordServiceBase` Configurations can be retrieved in several ways. @@ -1465,7 +1443,7 @@ create alternative instances if you wish. ### Creating a ModelLoadService object The class is defined in -`invokeai.app.services.model_loader_service`. It is initialized with +`invokeai.app.services.model_load`. It is initialized with an InvokeAIAppConfig object, from which it gets configuration information such as the user's desired GPU and precision, and with a previously-created `ModelRecordServiceBase` object, from which it @@ -1475,8 +1453,8 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_record_service import ModelRecordServiceBase -from invokeai.app.services.model_loader_service import ModelLoadService +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.app.services.model_load import ModelLoadService config = InvokeAIAppConfig.get_config() store = ModelRecordServiceBase.open(config) @@ -1487,14 +1465,11 @@ Note that we are relying on the contents of the application configuration to choose the implementation of `ModelRecordServiceBase`. 
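+
+As a quick orientation, here is a hedged sketch of the record store
+and the loader working together; the key is a placeholder, and both
+calls are documented in detail below:
+
+```
+model_config = store.get_model('f13dd932c0c35c22dcb8d6cda4203764')  # record lookup
+loaded_model = loader.load_model_by_config(model_config)            # load into RAM
+with loaded_model as model:
+    ...  # model is locked into the execution device for the duration of the context
+```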
-### get_model(key, [submodel_type], [context]) -> ModelInfo:
+### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel
 
-*** TO DO: change to get_model(key, context=None, **kwargs)
-
-The `get_model()` method, like its similarly-named cousin in
-`ModelRecordService`, receives the unique key that identifies the
-model. It loads the model into memory, gets the model ready for use,
-and returns a `ModelInfo` object.
+The `load_model_by_key()` method receives the unique key that
+identifies the model. It loads the model into memory, gets the model
+ready for use, and returns a `LoadedModel` object.
 
 The optional second argument, `subtype` is a `SubModelType` string
 enum, such as "vae". It is mandatory when used with a main model, and
 
@@ -1504,46 +1479,64 @@ The optional third argument, `context` can be provided by
 invocation to trigger model load event reporting. See below for
 details.
 
-The returned `ModelInfo` object shares some fields in common with
-`ModelConfigBase`, but is otherwise a completely different beast:
+The returned `LoadedModel` object contains a copy of the configuration
+record returned by the model record `get_model()` method, as well as
+the in-memory loaded model:
 
-| **Field Name** | **Type** | **Description** |
+
+| **Attribute Name** | **Type** | **Description** |
 |----------------|-----------------|------------------|
-| `key` | str | The model key derived from the ModelRecordService database |
-| `name` | str | Name of this model |
-| `base_model` | BaseModelType | Base model for this model |
-| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
-| `location` | Path or str | Location of the model on the filesystem |
-| `precision` | torch.dtype | The torch.precision to use for inference |
-| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
+| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
+| `model` | AnyModel | The instantiated model (details below) |
+| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
+
+Because the loader can return multiple model types, it is typed to
+return `AnyModel`, a `Union` of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models; `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The remaining types are self-explanatory.
 
-The types for `ModelInfo` and `SubModelType` can be imported from
-`invokeai.app.services.model_loader_service`.
 
-To use the model, you use the `ModelInfo` as a context manager using
-the following pattern:
+`LoadedModel` acts as a context manager. The context loads the model
+into the execution device (e.g. VRAM on CUDA systems), locks the model
+in the execution device for the duration of the context, and returns
+the model. Use it like this:
 
 ```
-model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+model_info = loader.load_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
 with model_info as vae:
     image = vae.decode(latents)[0]
 ```
 
-The `vae` model will stay locked in the GPU during the period of time
-it is in the context manager's scope.
+`load_model_by_key()` may raise any of the following exceptions:
 
-- `UnknownModelException`  -- key not in database
-- `ModelNotFoundException` -- key in database but model not found at path
-- `InvalidModelException`  -- the model is guilty of a variety of sins
+- `UnknownModelException`   -- key not in database
+- `ModelNotFoundException`  -- key in database but model not found at path
+- `NotImplementedException` -- the loader doesn't know how to load this type of model
 
-** TO DO: ** Resolve discrepancy between ModelInfo.location and
-ModelConfig.path.
+### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
+
+This is similar to `load_model_by_key`, but it identifies the model
+by the combination of its name, type and base, which it passes to the
+model record config store for retrieval. If successful, this method
+returns a `LoadedModel`. It can raise the following exceptions:
+
+```
+UnknownModelException -- model with these attributes not known
+NotImplementedException -- the loader doesn't know how to load this type of model
+ValueError -- more than one model matches this combination of base/type/name
+```
+
+### load_model_by_config(config, [submodel], [context]) -> LoadedModel
+
+This method takes an `AnyModelConfig` returned by
+`ModelRecordService.get_model()` and returns the corresponding loaded
+model. It may raise a `NotImplementedException`.
 
 ### Emitting model loading events
 
 When the `context` argument is passed to `load_model_*()`, it will
 retrieve the invocation event bus from the passed `InvocationContext`
 object to emit events on the invocation bus. The two events are
 "model_load_started" and "model_load_completed". Both carry the
@@ -1563,3 +1556,97 @@ payload=dict(
 )
 ```
 
+### Adding Model Loaders
+
+Model loaders are small classes that inherit from the `ModelLoader`
+base class. They typically implement one method, `_load_model()`, whose
+signature is:
+
+```
+def _load_model(
+    self,
+    model_path: Path,
+    model_variant: Optional[ModelRepoVariant] = None,
+    submodel_type: Optional[SubModelType] = None,
+) -> AnyModel:
+```
+
+`_load_model()` will be passed the path to the model on disk, an
+optional repository variant (used by the diffusers loaders to select,
+e.g., the `fp16` variant), and an optional `submodel_type` for main and
+ONNX models.
 
+To install a new loader, place it in
+`invokeai/backend/model_manager/load/model_loaders`. Inherit from
+`ModelLoader` and use the `@AnyModelLoader.register()` decorator to
+indicate what type of models the loader can handle.
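+
+In skeleton form, a new loader looks something like this (the
+base/type/format values here are illustrative placeholders, not a
+working loader):
+
+```
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
+class MyNewLoader(ModelLoader):
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        ...  # read the file(s) at model_path and return the in-memory model
+```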
+ +Here is a complete example from `generic_diffusers.py`, which is able +to load several different diffusers types: + +``` +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result +``` + +Note that a loader can register itself to handle several different +model types. An exception will be raised if more than one loader tries +to register the same model type. + +#### Conversion + +Some models require conversion to diffusers format before they can be +loaded. These loaders should override two additional methods: + +``` +_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool +_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: +``` + +The first method accepts the model configuration, the path to where +the unmodified model is currently installed, and a proposed +destination for the converted model. This method returns True if the +model needs to be converted. It typically does this by comparing the +last modification time of the original model file to the modification +time of the converted model. In some cases you will also want to check +the modification date of the configuration record, in the event that +the user has changed something like the scheduler prediction type that +will require the model to be re-converted. See `controlnet.py` for an +example of this logic. + +The second method accepts the model configuration, the path to the +original model on disk, and the desired output path for the converted +model. It does whatever it needs to do to get the model into diffusers +format, and returns the Path of the resulting model. (The path should +ordinarily be the same as `output_path`.) 
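+
+As an illustration, here is a minimal sketch of how a conversion-aware
+loader might implement these two hooks; the mtime comparison and the
+`_convert_to_diffusers()` helper are hypothetical stand-ins, not the
+actual `controlnet.py` logic:
+
+```
+class MyCheckpointLoader(ModelLoader):
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
+        # Reconvert when no cached copy exists yet, or when the original
+        # checkpoint is newer than the cached conversion.
+        if not dest_path.exists():
+            return True
+        return model_path.stat().st_mtime > dest_path.stat().st_mtime
+
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+        # Hypothetical helper that rewrites the checkpoint in diffusers
+        # layout; save_pretrained() then writes it to output_path.
+        pipeline = self._convert_to_diffusers(config, model_path)
+        pipeline.save_pretrained(output_path)
+        return output_path
+```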
+
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index dcb8d219971..378961a0557 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -8,9 +8,6 @@
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
-from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
-from invokeai.backend.model_manager.load.model_cache import ModelCache
-from invokeai.backend.model_manager.metadata import ModelMetadataStore
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
@@ -30,9 +27,7 @@
 from ..services.invocation_services import InvocationServices
 from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
-from ..services.model_install import ModelInstallService
 from ..services.model_manager.model_manager_default import ModelManagerService
-from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -98,28 +93,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
         conditioning = ObjectSerializerForwardCache(
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
-        model_manager = ModelManagerService(config, logger)
-        model_record_service = ModelRecordServiceSQL(db=db)
-        model_loader = AnyModelLoader(
-            app_config=config,
-            logger=logger,
-            ram_cache=ModelCache(
-                max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
-            ),
-            convert_cache=ModelConvertCache(
-                cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
-            ),
-        )
-        model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
         download_queue_service = DownloadQueueService(event_bus=events)
-        model_install_service = ModelInstallService(
-            app_config=config,
-            record_store=model_record_service,
-            download_queue=download_queue_service,
-            metadata_store=ModelMetadataStore(db=db),
-            event_bus=events,
+        model_manager = ModelManagerService.build_model_manager(
+            app_config=config, db=db, download_queue=download_queue_service, events=events
         )
-        model_manager = ModelManagerService(config, logger)  # TO DO: legacy model manager v1.
Remove names = SimpleNameService() performance_statistics = InvocationStatsService() processor = DefaultInvocationProcessor() @@ -143,9 +120,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger invocation_cache=invocation_cache, logger=logger, model_manager=model_manager, - model_records=model_record_service, download_queue=download_queue_service, - model_install=model_install_service, names=names, performance_statistics=performance_statistics, processor=processor, diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_manager_v2.py similarity index 86% rename from invokeai/app/api/routers/model_records.py rename to invokeai/app/api/routers/model_manager_v2.py index f9a3e408985..4fc785e4f7a 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -32,7 +32,7 @@ from ..dependencies import ApiDependencies -model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"]) +model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) class ModelsList(BaseModel): @@ -52,7 +52,7 @@ class ModelTagSet(BaseModel): tags: Set[str] -@model_records_router.get( +@model_manager_v2_router.get( "/", operation_id="list_model_records", ) @@ -65,7 +65,7 @@ async def list_model_records( ), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store found_models: list[AnyModelConfig] = [] if base_models: for base_model in base_models: @@ -81,7 +81,7 @@ async def list_model_records( return ModelsList(models=found_models) -@model_records_router.get( +@model_manager_v2_router.get( "/i/{key}", operation_id="get_model_record", responses={ @@ -94,24 +94,27 @@ async def get_model_record( key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store try: - return record_store.get_model(key) + config: AnyModelConfig = record_store.get_model(key) + return config except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.get("/meta", operation_id="list_model_summary") +@model_manager_v2_router.get("/meta", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of models per page"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), ) -> PaginatedResults[ModelSummary]: """Gets a page of model summary data.""" - return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by) + record_store = ApiDependencies.invoker.services.model_manager.store + results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) + return results -@model_records_router.get( +@model_manager_v2_router.get( "/meta/i/{key}", operation_id="get_model_metadata", responses={ @@ -124,24 +127,25 @@ async def get_model_metadata( key: str = Path(description="Key of the model repo metadata to fetch."), ) -> Optional[AnyModelRepoMetadata]: """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_records - result = 
record_store.get_metadata(key) + record_store = ApiDependencies.invoker.services.model_manager.store + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) if not result: raise HTTPException(status_code=404, detail="No metadata for a model with this key") return result -@model_records_router.get( +@model_manager_v2_router.get( "/tags", operation_id="list_tags", ) async def list_tags() -> Set[str]: """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_records - return record_store.list_tags() + record_store = ApiDependencies.invoker.services.model_manager.store + result: Set[str] = record_store.list_tags() + return result -@model_records_router.get( +@model_manager_v2_router.get( "/tags/search", operation_id="search_by_metadata_tags", ) @@ -149,12 +153,12 @@ async def search_by_metadata_tags( tags: Set[str] = Query(default=None, description="Tags to search for"), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store results = record_store.search_by_metadata_tag(tags) return ModelsList(models=results) -@model_records_router.patch( +@model_manager_v2_router.patch( "/i/{key}", operation_id="update_model_record", responses={ @@ -172,9 +176,9 @@ async def update_model_record( ) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store try: - model_response = record_store.update_model(key, config=info) + model_response: AnyModelConfig = record_store.update_model(key, config=info) logger.info(f"Updated model: {key}") except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) @@ -184,7 +188,7 @@ async def update_model_record( return model_response -@model_records_router.delete( +@model_manager_v2_router.delete( "/i/{key}", operation_id="del_model_record", responses={ @@ -205,7 +209,7 @@ async def del_model_record( logger = ApiDependencies.invoker.services.logger try: - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install installer.delete(key) logger.info(f"Deleted model: {key}") return Response(status_code=204) @@ -214,7 +218,7 @@ async def del_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.post( +@model_manager_v2_router.post( "/i/", operation_id="add_model_record", responses={ @@ -229,7 +233,7 @@ async def add_model_record( ) -> AnyModelConfig: """Add a model using the configuration information appropriate for its type.""" logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store if config.key == "": config.key = sha1(randbytes(100)).hexdigest() logger.info(f"Created model {config.key} for {config.name}") @@ -243,10 +247,11 @@ async def add_model_record( raise HTTPException(status_code=415) # now fetch it out - return record_store.get_model(config.key) + result: AnyModelConfig = record_store.get_model(config.key) + return result -@model_records_router.post( +@model_manager_v2_router.post( "/import", operation_id="import_model_record", responses={ @@ -322,7 +327,7 @@ async def 
import_model( logger = ApiDependencies.invoker.services.logger try: - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install result: ModelInstallJob = installer.import_model( source=source, config=config, @@ -340,17 +345,17 @@ async def import_model( return result -@model_records_router.get( +@model_manager_v2_router.get( "/import", operation_id="list_model_install_jobs", ) async def list_model_install_jobs() -> List[ModelInstallJob]: """Return list of model install jobs.""" - jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs() + jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() return jobs -@model_records_router.get( +@model_manager_v2_router.get( "/import/{id}", operation_id="get_model_install_job", responses={ @@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: """Return model install job corresponding to the given source.""" try: - return ApiDependencies.invoker.services.model_install.get_job_by_id(id) + result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) + return result except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.delete( +@model_manager_v2_router.delete( "/import/{id}", operation_id="cancel_model_install_job", responses={ @@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id")) ) async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: """Cancel the model install job(s) corresponding to the given job ID.""" - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install try: job = installer.get_job_by_id(id) except ValueError as e: @@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job installer.cancel_job(job) -@model_records_router.patch( +@model_manager_v2_router.patch( "/import", operation_id="prune_model_install_jobs", responses={ @@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job ) async def prune_model_install_jobs() -> Response: """Prune all completed and errored jobs from the install job list.""" - ApiDependencies.invoker.services.model_install.prune_jobs() + ApiDependencies.invoker.services.model_manager.install.prune_jobs() return Response(status_code=204) -@model_records_router.patch( +@model_manager_v2_router.patch( "/sync", operation_id="sync_models_to_config", responses={ @@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response: Model files without a corresponding record in the database are added. Orphan records without a models file are deleted. 
""" - ApiDependencies.invoker.services.model_install.sync_to_config() + ApiDependencies.invoker.services.model_manager.install.sync_to_config() return Response(status_code=204) -@model_records_router.put( +@model_manager_v2_router.put( "/merge", operation_id="merge", ) @@ -451,7 +457,7 @@ async def merge( try: logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install merger = ModelMerger(installer) model_names = [installer.record_store.get_model(x).name for x in keys] response = merger.merge_diffusion_models_and_save( diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8f83820cf89..0aa7aa0ecba 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -8,8 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from starlette.exceptions import HTTPException -from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import MergeInterpolationMethod +from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType from invokeai.backend.model_management.models import ( OPENAPI_MODEL_CONFIGS, InvalidModelException, diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index f48074de7c7..851cbc8160e 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -48,7 +48,7 @@ boards, download_queue, images, - model_records, + model_manager_v2, models, session_queue, sessions, @@ -114,8 +114,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(models.models_router, prefix="/api") -app.include_router(model_records.model_records_router, prefix="/api") +app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0e1a6bdc6fb..3850fb6cc3d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -3,6 +3,7 @@ import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from transformers import CLIPTokenizer import invokeai.backend.util.logging as logger from invokeai.app.invocations.fields import ( @@ -68,18 +69,18 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_records.load_model( + tokenizer_info = context.services.model_manager.load.load_model_by_key( **self.clip.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_records.load_model( + text_encoder_info = context.services.model_manager.load.load_model_by_key( **self.clip.text_encoder.model_dump(), context=context, ) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( 
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 0e1a6bdc6fb..3850fb6cc3d 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -3,6 +3,7 @@
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
+from transformers import CLIPTokenizer

 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
@@ -68,18 +69,18 @@ class CompelInvocation(BaseInvocation):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
             **self.clip.tokenizer.model_dump(),
             context=context,
         )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
             **self.clip.text_encoder.model_dump(),
             context=context,
         )

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                     **lora.model_dump(exclude={"weight"}), context=context
                 )
                 assert isinstance(lora_info.model, LoRAModelRaw)
@@ -93,7 +94,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                loaded_model = context.services.model_records.load_model(
+                loaded_model = context.services.model_manager.load.load_model_by_key(
                     **self.clip.text_encoder.model_dump(),
                     context=context,
                 ).model
@@ -164,11 +165,11 @@ def run_clip_compel(
     lora_prefix: str,
     zero_on_empty: bool,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-    tokenizer_info = context.services.model_records.load_model(
+    tokenizer_info = context.services.model_manager.load.load_model_by_key(
         **clip_field.tokenizer.model_dump(),
         context=context,
     )
-    text_encoder_info = context.services.model_records.load_model(
+    text_encoder_info = context.services.model_manager.load.load_model_by_key(
         **clip_field.text_encoder.model_dump(),
         context=context,
     )
@@ -196,7 +197,7 @@ def run_clip_compel(

     def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         for lora in clip_field.loras:
-            lora_info = context.services.model_records.load_model(
+            lora_info = context.services.model_manager.load.load_model_by_key(
                 **lora.model_dump(exclude={"weight"}), context=context
             )
             lora_model = lora_info.model
@@ -211,7 +212,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
     for trigger in extract_ti_triggers_from_prompt(prompt):
         name = trigger[1:-1]
         try:
-            ti_model = context.services.model_records.load_model_by_attr(
+            ti_model = context.services.model_manager.load.load_model_by_attr(
                 model_name=name,
                 base_model=text_encoder_info.config.base,
                 model_type=ModelType.TextualInversion,
@@ -448,9 +449,9 @@ def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:


 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -462,7 +463,9 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))


-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")

@@ -475,24 +478,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
+    tokens: List[str] = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens


-def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
+            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix

         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)


-def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
+def log_tokenization_for_prompt_object(
+    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -532,7 +540,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
             log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)


-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+    text: str,
+    tokenizer: CLIPTokenizer,
+    display_label: Optional[str] = None,
+    truncate_if_too_long: Optional[bool] = False,
+) -> None:
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 063b23fa589..289da2dd73d 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -3,13 +3,15 @@
 import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import Iterator, List, Literal, Optional, Tuple, Union
+from typing import Any, Iterator, List, Literal, Optional, Tuple, Union

 import einops
 import numpy as np
+import numpy.typing as npt
 import torch
 import torchvision.transforms as T
-from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
+from diffusers import AutoencoderKL, AutoencoderTiny
+from diffusers.configuration_utils import ConfigMixin
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
 from diffusers.models.attention_processor import (
@@ -18,8 +20,10 @@
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
+from PIL import Image
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize

@@ -46,9 +50,10 @@
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.embeddings.lora import LoRAModelRaw
 from invokeai.backend.embeddings.model_patcher import ModelPatcher
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.model_manager import AnyModel, BaseModelType
+from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
 from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -123,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         ui_order=4,
     )

-    def prep_mask_tensor(self, mask_image):
+    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
         if mask_image.mode != "L":
             mask_image = mask_image.convert("L")
-        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
         if mask_tensor.dim() == 3:
             mask_tensor = mask_tensor.unsqueeze(0)
         # if shape is not None:
@@ -136,25 +141,25 @@ def prep_mask_tensor(self, mask_image):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
         if self.image is not None:
-            image = context.images.get_pil(self.image.image_name)
-            image = image_resized_to_grid_as_tensor(image.convert("RGB"))
-            if image.dim() == 3:
-                image = image.unsqueeze(0)
+            image = context.services.images.get_pil_image(self.image.image_name)
+            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+            if image_tensor.dim() == 3:
+                image_tensor = image_tensor.unsqueeze(0)
         else:
-            image = None
+            image_tensor = None

         mask = self.prep_mask_tensor(
             context.images.get_pil(self.mask.image_name),
         )

-        if image is not None:
-            vae_info = context.services.model_records.load_model(
+        if image_tensor is not None:
+            vae_info = context.services.model_manager.load.load_model_by_key(
                 **self.vae.vae.model_dump(),
                 context=context,
             )

-            img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
-            masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
+            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
             # TODO:
             masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

@@ -177,7 +182,7 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.services.model_records.load_model(
+    orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
         **scheduler_info.model_dump(),
         context=context,
     )
@@ -188,7 +193,7 @@ def get_scheduler(
         scheduler_config = scheduler_config["_backup"]
     scheduler_config = {
         **scheduler_config,
-        **scheduler_extra_config,
+        **scheduler_extra_config,  # FIXME
         "_backup": scheduler_config,
     }

@@ -201,6 +206,7 @@ def get_scheduler(
     # hack copied over from generate.py
     if not hasattr(scheduler, "uses_inpainting_model"):
         scheduler.uses_inpainting_model = lambda: False
+    assert isinstance(scheduler, Scheduler)
     return scheduler


@@ -284,7 +290,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
     )

     @field_validator("cfg_scale")
-    def ge_one(cls, v):
+    def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):
             for i in v:
@@ -298,9 +304,9 @@ def ge_one(cls, v):
     def get_conditioning_data(
         self,
         context: InvocationContext,
-        scheduler,
-        unet,
-        seed,
+        scheduler: Scheduler,
+        unet: UNet2DConditionModel,
+        seed: int,
     ) -> ConditioningData:
         positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
         c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@@ -323,7 +329,7 @@ def get_conditioning_data(
             ),
         )

-        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
+        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(  # FIXME
             scheduler,
             # for ddim scheduler
             eta=0.0,  # ddim_eta
@@ -335,8 +341,8 @@ def get_conditioning_data(

     def create_pipeline(
         self,
-        unet,
-        scheduler,
+        unet: UNet2DConditionModel,
+        scheduler: Scheduler,
     ) -> StableDiffusionGeneratorPipeline:
         # TODO:
         # configure_model_padding(
@@ -347,10 +353,10 @@ def create_pipeline(

         class FakeVae:
             class FakeVaeConfig:
-                def __init__(self):
+                def __init__(self) -> None:
                     self.block_out_channels = [0]

-            def __init__(self):
+            def __init__(self) -> None:
                 self.config = FakeVae.FakeVaeConfig()

         return StableDiffusionGeneratorPipeline(
@@ -367,11 +373,11 @@ def __init__(self):
     def prep_control_data(
         self,
         context: InvocationContext,
-        control_input: Union[ControlField, List[ControlField]],
+        control_input: Optional[Union[ControlField, List[ControlField]]],
         latents_shape: List[int],
         exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
-    ) -> List[ControlNetData]:
+    ) -> Optional[List[ControlNetData]]:
         # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
         control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
         control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
@@ -394,7 +400,7 @@ def prep_control_data(
             controlnet_data = []
             for control_info in control_list:
                 control_model = exit_stack.enter_context(
-                    context.services.model_records.load_model(
+                    context.services.model_manager.load.load_model_by_key(
                         key=control_info.control_model.key,
                         context=context,
                     )
@@ -460,23 +466,25 @@ def prep_ip_adapter_data(
         conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_records.load_model(
+                context.services.model_manager.load.load_model_by_key(
                     key=single_ip_adapter.ip_adapter_model.key,
                     context=context,
                 )
             )

-            image_encoder_model_info = context.services.model_records.load_model(
+            image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
                 key=single_ip_adapter.image_encoder_model.key,
                 context=context,
             )

             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
-            single_ipa_images = single_ip_adapter.image
-            if not isinstance(single_ipa_images, list):
-                single_ipa_images = [single_ipa_images]
+            single_ipa_image_fields = single_ip_adapter.image
+            if not isinstance(single_ipa_image_fields, list):
+                single_ipa_image_fields = [single_ipa_image_fields]

-            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
+            single_ipa_images = [
+                context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
+            ]

             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -520,21 +528,19 @@ def run_t2i_adapters(

         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_records.load_model(
+            t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
                 key=t2i_adapter_field.t2i_adapter_model.key,
                 context=context,
             )
             image = context.images.get_pil(t2i_adapter_field.image.image_name)

             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
-            if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
+            if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
                 max_unet_downscale = 8
-            elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
+            elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
                 max_unet_downscale = 4
             else:
-                raise ValueError(
-                    f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
-                )
+                raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")

             t2i_adapter_model: T2IAdapter
             with t2i_adapter_model_info as t2i_adapter_model:
@@ -582,7 +588,15 @@ def run_t2i_adapters(

     # original idea by https://github.com/AmericanPresidentJimmyCarter
     # TODO: research more for second order schedulers timesteps
-    def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
+    def init_scheduler(
+        self,
+        scheduler: Union[Scheduler, ConfigMixin],
+        device: torch.device,
+        steps: int,
+        denoising_start: float,
+        denoising_end: float,
+    ) -> Tuple[int, List[int], int]:
+        assert isinstance(scheduler, ConfigMixin)
         if scheduler.config.get("cpu_only", False):
             scheduler.set_timesteps(steps, device="cpu")
             timesteps = scheduler.timesteps.to(device=device)
@@ -594,11 +608,11 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en
         _timesteps = timesteps[:: scheduler.order]

         # get start timestep index
-        t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
+        t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
         t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))

         # get end timestep index
-        t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
+        t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
         t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))

         # apply order to indexes
@@ -611,7 +625,9 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en

         return num_inference_steps, timesteps, init_timestep

-    def prep_inpaint_mask(self, context: InvocationContext, latents):
+    def prep_inpaint_mask(
+        self, context: InvocationContext, latents: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         if self.denoise_mask is None:
             return None, None

@@ -660,12 +676,19 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
             do_classifier_free_guidance=True,
         )

-        def step_callback(state: PipelineIntermediateState):
-            context.util.sd_step_callback(state, self.unet.unet.base_model)
+        # Get the source node id (we are invoking the prepared node)
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
+        # get the unet's config so that we can pass the base to dispatch_progress()
+        unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())

-        def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
+        def step_callback(state: PipelineIntermediateState) -> None:
+            self.dispatch_progress(context, source_node_id, state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                     **lora.model_dump(exclude={"weight"}),
                     context=context,
                 )
@@ -673,7 +696,7 @@ def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
                 del lora_info
                 return

-        unet_info = context.services.model_records.load_model(
+        unet_info = context.services.model_manager.load.load_model_by_key(
             **self.unet.unet.model_dump(),
             context=context,
         )
@@ -783,7 +806,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        vae_info = context.services.model_records.load_model(
+        vae_info = context.services.model_manager.load.load_model_by_key(
             **self.vae.vae.model_dump(),
             context=context,
         )
@@ -961,8 +984,9 @@ class ImageToLatentsInvocation(BaseInvocation):
     fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

     @staticmethod
-    def vae_encode(vae_info, upcast, tiled, image_tensor):
+    def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
         with vae_info as vae:
+            assert isinstance(vae, torch.nn.Module)
             orig_dtype = vae.dtype
             if upcast:
                 vae.to(dtype=torch.float32)
@@ -1008,7 +1032,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)

-        vae_info = context.services.model_records.load_model(
+        vae_info = context.services.model_manager.load.load_model_by_key(
             **self.vae.vae.model_dump(),
             context=context,
         )
@@ -1026,14 +1050,19 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
     @singledispatchmethod
     @staticmethod
     def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        assert isinstance(vae, torch.nn.Module)
         image_tensor_dist = vae.encode(image_tensor).latent_dist
-        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        latents: torch.Tensor = image_tensor_dist.sample().to(
+            dtype=vae.dtype
+        )  # FIXME: uses torch.randn. make reproducible!
         return latents

     @_encode_to_tensor.register
     @staticmethod
     def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
-        return vae.encode(image_tensor).latents
+        assert isinstance(vae, torch.nn.Module)
+        latents: torch.FloatTensor = vae.encode(image_tensor).latents
+        return latents


 @invocation(
@@ -1066,7 +1095,12 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         # TODO:
         device = choose_torch_device()

-        def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+        def slerp(
+            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
+            v0: Union[torch.Tensor, npt.NDArray[Any]],
+            v1: Union[torch.Tensor, npt.NDArray[Any]],
+            DOT_THRESHOLD: float = 0.9995,
+        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
             """
             Spherical linear interpolation
             Args:
@@ -1099,12 +1133,16 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
             v2 = s0 * v0 + s1 * v1

             if inputs_are_torch:
-                v2 = torch.from_numpy(v2).to(device)
-
-            return v2
+                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
+                return v2_torch
+            else:
+                assert isinstance(v2, np.ndarray)
+                return v2

         # blend
-        blended_latents = slerp(self.alpha, latents_a, latents_b)
+        bl = slerp(self.alpha, latents_a, latents_b)
+        assert isinstance(bl, torch.Tensor)
+        blended_latents: torch.Tensor = bl  # for type checking convenience

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
@@ -1197,15 +1235,19 @@ class IdealSizeInvocation(BaseInvocation):
         description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
     )

-    def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
+    def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
         return tuple((x - x % multiple_of) for x in args)

     def invoke(self, context: InvocationContext) -> IdealSizeOutput:
+        unet_config = context.services.model_manager.load.load_model_by_key(
+            **self.unet.unet.model_dump(),
+            context=context,
+        )
         aspect = self.width / self.height
-        dimension = 512
-        if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
+        dimension: float = 512
+        if unet_config.base == BaseModelType.StableDiffusion2:
             dimension = 768
-        elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
+        elif unet_config.base == BaseModelType.StableDiffusionXL:
             dimension = 1024
         dimension = dimension * self.multiplier
         min_dimension = math.floor(dimension * 0.5)
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index e2ea7442839..fa6e8b98da0 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -17,7 +17,7 @@


 class ModelInfo(BaseModel):
-    key: str = Field(description="Info to load submodel")
+    key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
     submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
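
The newly annotated slerp closure keeps the math unchanged: interpolate along the great circle between v0 and v1, falling back to linear interpolation when the inputs are nearly colinear. A standalone NumPy re-implementation of the same formula, for illustration only (this is not the patched closure, which also handles torch tensors and device placement):

    import numpy as np

    def slerp(t: float, v0: np.ndarray, v1: np.ndarray, DOT_THRESHOLD: float = 0.9995) -> np.ndarray:
        """Spherical linear interpolation between flattened vectors v0 and v1."""
        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            return (1 - t) * v0 + t * v1  # nearly colinear: plain lerp
        theta_0 = np.arccos(dot)           # angle between the vectors
        theta_t = theta_0 * t
        s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
        s1 = np.sin(theta_t) / np.sin(theta_0)
        return s0 * v0 + s1 * v1

    a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
    print(slerp(0.5, a, b))  # ~[0.7071, 0.7071]: the midpoint stays on the unit circle
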
"InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", @@ -82,9 +78,7 @@ def __init__( self.image_records = image_records self.logger = logger self.model_manager = model_manager - self.model_records = model_records self.download_queue = download_queue - self.model_install = model_install self.processor = processor self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 0c63b545ff2..6c893021de4 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -43,8 +43,10 @@ def start(self, invoker: Invoker) -> None: @contextmanager def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + # This is to handle case of the model manager not being initialized, which happens + # during some tests. services = self._invoker.services - if services.model_records is None or services.model_records.loader is None: + if services.model_manager is None or services.model_manager.load is None: yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. @@ -60,9 +62,8 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - # TO DO [LS]: clean up loader service - shouldn't be an attribute of model records - assert services.model_records.loader is not None - services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id] + assert services.model_manager.load is not None + services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 7228806e809..f298d98ce6d 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel @@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """Given a model's key, load it and return the LoadedModel object.""" + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's key, load it and return the LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel: For main (pipeline models), the submodel to fetch. 
+ :param context: Invocation context used for event reporting + """ pass @abstractmethod - def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """Given a model's configuration, load it and return the LoadedModel object.""" + def load_model_by_config( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's configuration, load it and return the LoadedModel object. + + :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) + :param submodel: For main (pipeline models), the submodel to fetch. + :param context: Invocation context used for event reporting + """ pass + + @abstractmethod + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. + + This is provided for API compatability with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that ws returned. + + :param model_name: Name of to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. + + Exceptions: UnknownModelException -- model with these attributes not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination + """ diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 80e2fe161d0..67107cada6e 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -3,12 +3,14 @@ from typing import Optional +from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase +from invokeai.backend.model_manager.load.model_cache import ModelCacheBase from invokeai.backend.util.logging import InvokeAILogger from .model_load_base import ModelLoadServiceBase @@ -21,7 +23,7 @@ def __init__( self, app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, - ram_cache: Optional[ModelCacheBase] = None, + ram_cache: Optional[ModelCacheBase[AnyModel]] = None, convert_cache: Optional[ModelConvertCacheBase] = None, ): """Initialize the model load service.""" @@ -44,11 +46,104 @@ def __init__( ), ) - def load_model_by_key(self, key: str, 
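
Whichever variant is called, the contract for callers is the same: each returns a LoadedModel, which is used as a context manager so the RAM cache can lock the model while it is in use and release it afterwards. A sketch of that calling convention, with `loader` standing in for any concrete ModelLoadServiceBase and the helper name being illustrative:

    def encode_with_vae(loader, vae_key: str, image_tensor):
        vae_info = loader.load_model_by_key(vae_key)  # returns a LoadedModel
        with vae_info as vae:  # __enter__ locks the model in the cache and returns it
            return vae.encode(image_tensor)  # the raw model is only valid inside the block
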
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 80e2fe161d0..67107cada6e 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -3,12 +3,14 @@

 from typing import Optional

+from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager import AnyModelConfig, SubModelType
+from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
-from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_load_base import ModelLoadServiceBase
@@ -21,7 +23,7 @@ def __init__(
         self,
         app_config: InvokeAIAppConfig,
         record_store: ModelRecordServiceBase,
-        ram_cache: Optional[ModelCacheBase] = None,
+        ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
         convert_cache: Optional[ModelConvertCacheBase] = None,
     ):
         """Initialize the model load service."""
@@ -44,11 +46,104 @@ def __init__(
             ),
         )

-    def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
-        """Given a model's key, load it and return the LoadedModel object."""
+    def load_model_by_key(
+        self,
+        key: str,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's key, load it and return the LoadedModel object.
+
+        :param key: Key of model config to be fetched.
+        :param submodel: For main (pipeline) models, the submodel to fetch.
+        :param context: Invocation context used for event reporting
+        """
         config = self._store.get_model(key)
-        return self.load_model_by_config(config, submodel_type)
+        return self.load_model_by_config(config, submodel_type, context)
+
+    def load_model_by_attr(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
+
+        This is provided for API compatibility with the get_model() method
+        in the original model manager. However, note that LoadedModel is
+        not the same as the original ModelInfo that was returned.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline) models, the submodel to fetch
+        :param context: The invocation context.
+
+        Exceptions: UnknownModelException -- model with these attributes not known
+                    NotImplementedException -- a model loader was not provided at initialization time
+                    ValueError -- more than one model matches this combination
+        """
+        configs = self._store.search_by_attr(model_name, base_model, model_type)
+        if len(configs) == 0:
+            raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
+        elif len(configs) > 1:
+            raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
+        else:
+            return self.load_model_by_key(configs[0].key, submodel)
+
+    def load_model_by_config(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's configuration, load it and return the LoadedModel object.
+
+        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
+        :param submodel: For main (pipeline) models, the submodel to fetch.
+        :param context: Invocation context used for event reporting
+        """
+        if context:
+            self._emit_load_event(
+                context=context,
+                model_config=model_config,
+            )
+        loaded_model = self._any_loader.load_model(model_config, submodel_type)
+        if context:
+            self._emit_load_event(
+                context=context,
+                model_config=model_config,
+                loaded=True,
+            )
+        return loaded_model
+
+    def _emit_load_event(
+        self,
+        context: InvocationContext,
+        model_config: AnyModelConfig,
+        loaded: Optional[bool] = False,
+    ) -> None:
+        if context.services.queue.is_canceled(context.graph_execution_state_id):
+            raise CanceledException()

-    def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
-        """Given a model's configuration, load it and return the LoadedModel object."""
-        return self._any_loader.load_model(config, submodel_type)
+        if not loaded:
+            context.services.events.emit_model_load_started(
+                queue_id=context.queue_id,
+                queue_item_id=context.queue_item_id,
+                queue_batch_id=context.queue_batch_id,
+                graph_execution_state_id=context.graph_execution_state_id,
+                model_config=model_config,
+            )
+        else:
+            context.services.events.emit_model_load_completed(
+                queue_id=context.queue_id,
+                queue_item_id=context.queue_item_id,
+                queue_batch_id=context.queue_batch_id,
+                graph_execution_state_id=context.graph_execution_state_id,
+                model_config=model_config,
+            )
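
The attribute lookup above funnels into the key-based path after a uniqueness check: zero matches raises UnknownModelException and more than one raises ValueError. A hypothetical call, assuming `loader` is a ModelLoadService and exactly one matching main model is installed (the model name is made up for illustration):

    from invokeai.backend.model_manager import BaseModelType, ModelType

    loaded = loader.load_model_by_attr(
        model_name="stable-diffusion-v1-5",  # hypothetical installed model
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
    )
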
""" pass + + @property + @abstractmethod + def store(self) -> ModelRecordServiceBase: + """Return the ModelRecordServiceBase used to store and retrieve configuration records.""" + pass + + @property + @abstractmethod + def load(self) -> ModelLoadServiceBase: + """Return the ModelLoadServiceBase used to load models from their configuration records.""" + pass + + @property + @abstractmethod + def install(self) -> ModelInstallServiceBase: + """Return the ModelInstallServiceBase used to download and manipulate model files.""" + pass + + @abstractmethod + def start(self, invoker: Invoker) -> None: + pass + + @abstractmethod + def stop(self, invoker: Invoker) -> None: + pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index ad0fd66dbbd..028d4af6159 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -3,6 +3,7 @@ from typing_extensions import Self +from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger @@ -10,9 +11,9 @@ from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase from ..events.events_base import EventServiceBase -from ..model_install import ModelInstallService -from ..model_load import ModelLoadService -from ..model_records import ModelRecordServiceSQL +from ..model_install import ModelInstallService, ModelInstallServiceBase +from ..model_load import ModelLoadService, ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_manager_base import ModelManagerServiceBase @@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase): model_manager.load -- Routines to load models into memory. 
""" + def __init__( + self, + store: ModelRecordServiceBase, + install: ModelInstallServiceBase, + load: ModelLoadServiceBase, + ): + self._store = store + self._install = install + self._load = load + + @property + def store(self) -> ModelRecordServiceBase: + return self._store + + @property + def install(self) -> ModelInstallServiceBase: + return self._install + + @property + def load(self) -> ModelLoadServiceBase: + return self._load + + def start(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "start"): + service.start(invoker) + + def stop(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "stop"): + service.stop(invoker) + @classmethod def build_model_manager( cls, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index e00dd4169d5..e2e98c7e896 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -10,15 +10,12 @@ from pydantic import BaseModel, Field -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager import ( AnyModelConfig, BaseModelType, - LoadedModel, ModelFormat, ModelType, - SubModelType, ) from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -111,52 +108,6 @@ def get_model(self, key: str) -> AnyModelConfig: """ pass - @abstractmethod - def load_model( - self, - key: str, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch - :param context: Invocation context, used for event issuing. - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - pass - - @abstractmethod - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Key of model config to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. 
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index e00dd4169d5..e2e98c7e896 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -10,15 +10,12 @@

 from pydantic import BaseModel, Field

-from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.backend.model_manager import (
     AnyModelConfig,
     BaseModelType,
-    LoadedModel,
     ModelFormat,
     ModelType,
-    SubModelType,
 )
 from invokeai.backend.model_manager.load import AnyModelLoader
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@@ -111,52 +108,6 @@ def get_model(self, key: str) -> AnyModelConfig:
         """
         pass

-    @abstractmethod
-    def load_model(
-        self,
-        key: str,
-        submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
-    ) -> LoadedModel:
-        """
-        Load the indicated model into memory and return a LoadedModel object.
-
-        :param key: Key of model config to be fetched.
-        :param submodel: For main (pipeline models), the submodel to fetch
-        :param context: Invocation context, used for event issuing.
-
-        Exceptions: UnknownModelException -- model with this key not known
-          NotImplementedException -- a model loader was not provided at initialization time
-        """
-        pass
-
-    @abstractmethod
-    def load_model_by_attr(
-        self,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
-        submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
-    ) -> LoadedModel:
-        """
-        Load the indicated model into memory and return a LoadedModel object.
-
-        This is provided for API compatability with the get_model() method
-        in the original model manager. However, note that LoadedModel is
-        not the same as the original ModelInfo that ws returned.
-
-        :param model_name: Key of model config to be fetched.
-        :param base_model: Base model
-        :param model_type: Type of the model
-        :param submodel: For main (pipeline models), the submodel to fetch
-        :param context: The invocation context.
-
-        Exceptions: UnknownModelException -- model with this key not known
-          NotImplementedException -- a model loader was not provided at initialization time
-        """
-        pass
-
     @property
     @abstractmethod
     def metadata_store(self) -> ModelMetadataStore:
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index 28a77b1b1ab..f48175351de 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -46,8 +46,6 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Tuple, Union

-from invokeai.app.invocations.baseinvocation import InvocationContext
-from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
@@ -55,9 +53,8 @@
     ModelConfigFactory,
     ModelFormat,
     ModelType,
-    SubModelType,
 )
-from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
+from invokeai.backend.model_manager.load import AnyModelLoader
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException

 from ..shared.sqlite.sqlite_database import SqliteDatabase
@@ -220,74 +217,6 @@ def get_model(self, key: str) -> AnyModelConfig:
             model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
         return model

-    def load_model(
-        self,
-        key: str,
-        submodel: Optional[SubModelType],
-        context: Optional[InvocationContext] = None,
-    ) -> LoadedModel:
-        """
-        Load the indicated model into memory and return a LoadedModel object.
-
-        :param key: Key of model config to be fetched.
-        :param submodel: For main (pipeline models), the submodel to fetch.
-        :param context: Invocation context used for event reporting
-
-        Exceptions: UnknownModelException -- model with this key not known
-          NotImplementedException -- a model loader was not provided at initialization time
-        """
-        if not self._loader:
-            raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
-        # we can emit model loading events if we are executing with access to the invocation context
-
-        model_config = self.get_model(key)
-        if context:
-            self._emit_load_event(
-                context=context,
-                model_config=model_config,
-            )
-        loaded_model = self._loader.load_model(model_config, submodel)
-        if context:
-            self._emit_load_event(
-                context=context,
-                model_config=model_config,
-                loaded=True,
-            )
-        return loaded_model
-
-    def load_model_by_attr(
-        self,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
-        submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
-    ) -> LoadedModel:
-        """
-        Load the indicated model into memory and return a LoadedModel object.
-
-        This is provided for API compatability with the get_model() method
-        in the original model manager. However, note that LoadedModel is
-        not the same as the original ModelInfo that ws returned.
-
-        :param model_name: Key of model config to be fetched.
-        :param base_model: Base model
-        :param model_type: Type of the model
-        :param submodel: For main (pipeline models), the submodel to fetch
-        :param context: The invocation context.
-
-        Exceptions: UnknownModelException -- model with this key not known
-          NotImplementedException -- a model loader was not provided at initialization time
-          ValueError -- more than one model matches this combination
-        """
-        configs = self.search_by_attr(model_name, base_model, model_type)
-        if len(configs) == 0:
-            raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
-        elif len(configs) > 1:
-            raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
-        else:
-            return self.load_model(configs[0].key, submodel)
-
     def exists(self, key: str) -> bool:
         """
         Return True if a model with the indicated key exists in the databse.
@@ -476,29 +405,3 @@ def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
         return PaginatedResults(
             page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
         )
-
-    def _emit_load_event(
-        self,
-        context: InvocationContext,
-        model_config: AnyModelConfig,
-        loaded: Optional[bool] = False,
-    ) -> None:
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
-            raise CanceledException()
-
-        if not loaded:
-            context.services.events.emit_model_load_started(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
-                model_config=model_config,
-            )
-        else:
-            context.services.events.emit_model_load_completed(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
-                model_config=model_config,
-            )
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py
index b4734445110..1f9ac56518c 100644
--- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py
@@ -6,6 +6,7 @@ class Migration6Callback:
     def __call__(self, cursor: sqlite3.Cursor) -> None:
         self._recreate_model_triggers(cursor)
+        self._delete_ip_adapters(cursor)

     def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
         """
@@ -26,6 +27,22 @@ def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
             """
         )

+    def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
+        """
+        Delete all the IP adapters.
+
+        The model manager will automatically find and re-add them after the migration
+        is done. This allows the manager to add the correct image encoder to their
+        configuration records.
+        """
+
+        cursor.execute(
+            """--sql
+            DELETE FROM model_config
+            WHERE type='ip_adapter';
+            """
+        )
+

 def build_migration_6() -> Migration:
     """
@@ -33,6 +50,8 @@ def build_migration_6() -> Migration:

     This migration does the following:
     - Adds the model_config_updated_at trigger if it does not exist
+    - Delete all ip_adapter models so that the model prober can find and
+      update with the correct image encoder model.
     """
     migration_6 = Migration(
         from_version=5,
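
The IP-adapter cleanup is plain SQLite; on the next startup the install service re-probes and re-adds the deleted records, now carrying their image-encoder field. A standalone sketch of the equivalent statement (the database path is hypothetical, for illustration only):

    import sqlite3

    conn = sqlite3.connect("invokeai.db")  # hypothetical path
    with conn:  # commits on success, rolls back on error
        conn.execute("DELETE FROM model_config WHERE type='ip_adapter';")
    conn.close()
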
""" migration_6 = Migration( from_version=5, diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py index 4725181b8ed..bee8909c311 100644 --- a/invokeai/backend/embeddings/model_patcher.py +++ b/invokeai/backend/embeddings/model_patcher.py @@ -64,7 +64,7 @@ def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tup def apply_lora_unet( cls, unet: UNet2DConditionModel, - loras: List[Tuple[LoRAModelRaw, float]], + loras: Iterator[Tuple[LoRAModelRaw, float]], ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -307,7 +307,7 @@ class ONNXModelPatcher: def apply_lora_unet( cls, unet: OnnxRuntimeModel, - loras: List[Tuple[LoRAModelRaw, float]], + loras: Iterator[Tuple[LoRAModelRaw, float]], ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index b9649925e14..92ddef5ecc3 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -8,8 +8,8 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend import SilenceWarnings from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.silence_warnings import SilenceWarnings config = InvokeAIAppConfig.get_config() diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 9176bf1f49f..b4706ea99c0 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -8,7 +8,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from invokeai.backend.model_management.models.base import calc_model_size_by_data from .resampler import Resampler @@ -124,6 +123,9 @@ def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): self.attn_weights.to(device=self.device, dtype=self.dtype) def calc_size(self): + # workaround for circular import + from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data + return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) def _init_image_proj_model(self, state_dict): diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 4534a4892fb..9f0f774b499 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -21,7 +21,7 @@ """ import time from enum import Enum -from typing import Literal, Optional, Type, Union, Class +from typing import Literal, Optional, Type, Union import torch from diffusers import ModelMixin @@ -335,7 +335,7 @@ def make_config( cls, model_data: Union[Dict[str, Any], AnyModelConfig], key: Optional[str] = None, - dest_class: Optional[Type[Class]] = None, + dest_class: Optional[Type[ModelConfigBase]] = None, timestamp: Optional[float] = None, ) -> AnyModelConfig: """ @@ -347,14 +347,17 @@ def make_config( :param dest_class: The config class to be returned. If not provided, will be selected automatically. 
""" + model: Optional[ModelConfigBase] = None if isinstance(model_data, ModelConfigBase): model = model_data elif dest_class: - model = dest_class.validate_python(model_data) + model = dest_class.model_validate(model_data) else: - model = AnyModelConfigValidator.validate_python(model_data) + # mypy doesn't typecheck TypeAdapters well? + model = AnyModelConfigValidator.validate_python(model_data) # type: ignore + assert model is not None if key: model.key = key if timestamp: model.last_modified = timestamp - return model + return model # type: ignore diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 9d98ee30531..3d026af2269 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -18,8 +18,16 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.logging import InvokeAILogger @@ -32,7 +40,7 @@ class LoadedModel: config: AnyModelConfig locker: ModelLockerBase - def __enter__(self) -> AnyModel: # I think load_file() always returns a dict + def __enter__(self) -> AnyModel: """Context entry.""" self.locker.lock() return self.model @@ -171,6 +179,10 @@ def register( def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) cls._registry[key] = subclass return subclass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index c1dfe729af7..df83c8320d9 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -169,7 +169,7 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e # This needs to be implemented in subclasses that handle checkpoints - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: raise NotImplementedError # This needs to be implemented in the subclass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index b1deb215b2b..98d6f34cead 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ 
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index c1dfe729af7..df83c8320d9 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -169,7 +169,7 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT
             raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

     # This needs to be implemented in subclasses that handle checkpoints
-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         raise NotImplementedError

     # This needs to be implemented in the subclass
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index b1deb215b2b..98d6f34cead 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -246,7 +246,7 @@ def offload_unlocked_models(self, size_required: int) -> None:

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device."""
-        # These attributes are not in the base ModelMixin class but in derived classes.
+        # These attributes are not in the base ModelMixin class but in various derived classes.
         # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
         if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
index e61e2b46a63..d446d079336 100644
--- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py
+++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
@@ -35,28 +35,28 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path:
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
             raise Exception(f"Vae conversion not supported for model type: {config.base}")
         else:
             assert hasattr(config, "config")
             config_file = config.config

-        if weights_path.suffix == ".safetensors":
-            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
+        if model_path.suffix == ".safetensors":
+            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
         else:
-            checkpoint = torch.load(weights_path, map_location="cpu")
+            checkpoint = torch.load(model_path, map_location="cpu")

         # sometimes weights are hidden under "state_dict", and sometimes not
         if "state_dict" in checkpoint:
             checkpoint = checkpoint["state_dict"]

         convert_controlnet_to_diffusers(
-            weights_path,
+            model_path,
             output_path,
             original_config_file=self._app_config.root_path / config_file,
             image_size=512,
             scan_needed=True,
-            from_safetensors=weights_path.suffix == ".safetensors",
+            from_safetensors=model_path.suffix == ".safetensors",
         )
         return output_path
diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
index 03c26f3a0c0..114e317f3c6 100644
--- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
+++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
@@ -12,8 +12,9 @@
     ModelType,
     SubModelType,
 )
-from invokeai.backend.model_manager.load.load_base import AnyModelLoader
-from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+from ..load_base import AnyModelLoader
+from ..load_default import ModelLoader


 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
index a963e8403b9..23b4e1fccd6 100644
--- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
+++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
@@ -65,7 +65,7 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path:
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         assert isinstance(config, MainCheckpointConfig)
         variant = config.variant
         base = config.base
@@ -75,9 +75,9 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path

         config_file = config.config

-        self._logger.info(f"Converting {weights_path} to diffusers format")
+        self._logger.info(f"Converting {model_path} to diffusers format")
         convert_ckpt_to_diffusers(
-            weights_path,
+            model_path,
             output_path,
             model_type=self.model_base_to_model_type[base],
             model_version=base,
@@ -86,7 +86,7 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path
             extract_ema=True,
             scan_needed=True,
             pipeline_class=pipeline_class,
-            from_safetensors=weights_path.suffix == ".safetensors",
+            from_safetensors=model_path.suffix == ".safetensors",
             precision=self._torch_dtype,
             load_safety_checker=False,
         )
diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py
index 882ae055771..3983ea75950 100644
--- a/invokeai/backend/model_manager/load/model_loaders/vae.py
+++ b/invokeai/backend/model_manager/load/model_loaders/vae.py
@@ -37,7 +37,7 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path:
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         # TO DO: check whether sdxl VAE models convert.
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
             raise Exception(f"Vae conversion not supported for model type: {config.base}")
@@ -46,10 +46,10 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path
             "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
         )

-        if weights_path.suffix == ".safetensors":
-            checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
+        if model_path.suffix == ".safetensors":
+            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
         else:
-            checkpoint = torch.load(weights_path, map_location="cpu")
+            checkpoint = torch.load(model_path, map_location="cpu")

         # sometimes weights are hidden under "state_dict", and sometimes not
         if "state_dict" in checkpoint:
diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py
index 3f2d22595e2..c55eee48fa5 100644
--- a/invokeai/backend/model_manager/load/model_util.py
+++ b/invokeai/backend/model_manager/load/model_util.py
@@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
     bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
     other_files = set(all_files) - fp16_files - bit8_files

-    if variant is None:
+    if not variant:  # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
         files = other_files
     elif variant == "fp16":
         files = fp16_files
in f.name or ".8bit-" in f.name}
     other_files = set(all_files) - fp16_files - bit8_files
 
-    if variant is None:
+    if not variant:  # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
         files = other_files
     elif variant == "fp16":
         files = fp16_files
diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py
index a54938fdd5c..f7e1e1bed76 100644
--- a/invokeai/backend/model_manager/search.py
+++ b/invokeai/backend/model_manager/search.py
@@ -22,11 +22,12 @@ def find_main_models(model: Path) -> bool:
 import os
 from abc import ABC, abstractmethod
+from logging import Logger
 from pathlib import Path
 from typing import Callable, Optional, Set, Union
 
 from pydantic import BaseModel, Field
-from logging import Logger
+
 from invokeai.backend.util.logging import InvokeAILogger
 
 default_logger: Logger = InvokeAILogger.get_logger()
diff --git a/invokeai/backend/stable_diffusion/schedulers/__init__.py b/invokeai/backend/stable_diffusion/schedulers/__init__.py
index a4e9dbf9dad..0b780d3ee27 100644
--- a/invokeai/backend/stable_diffusion/schedulers/__init__.py
+++ b/invokeai/backend/stable_diffusion/schedulers/__init__.py
@@ -1 +1,3 @@
 from .schedulers import SCHEDULER_MAP  # noqa: F401
+
+__all__ = ["SCHEDULER_MAP"]
diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py
index 22b132370e6..20b630dfc62 100644
--- a/invokeai/frontend/install/model_install.py
+++ b/invokeai/frontend/install/model_install.py
@@ -513,7 +513,7 @@ def select_and_download_models(opt: Namespace) -> None:
     """Prompt user for install/delete selections and execute."""
     precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
     # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
-    config.precision = precision  # type: ignore
+    config.precision = precision
     install_helper = InstallHelper(config, logger)
     installer = install_helper.installer
diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py
index 27d2d2230a3..f839a4a8785 100644
--- a/tests/aa_nodes/test_graph_execution_state.py
+++ b/tests/aa_nodes/test_graph_execution_state.py
@@ -62,9 +62,7 @@ def mock_services() -> InvocationServices:
         invocation_cache=MemoryInvocationCache(max_cache_size=0),
         logger=logging,  # type: ignore
         model_manager=None,  # type: ignore
-        model_records=None,  # type: ignore
         download_queue=None,  # type: ignore
-        model_install=None,  # type: ignore
         names=None,  # type: ignore
         performance_statistics=InvocationStatsService(),
         processor=DefaultInvocationProcessor(),
diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py
index 437ea0f00d3..774f7501dc2 100644
--- a/tests/aa_nodes/test_invoker.py
+++ b/tests/aa_nodes/test_invoker.py
@@ -65,9 +65,7 @@ def mock_services() -> InvocationServices:
         invocation_cache=MemoryInvocationCache(max_cache_size=0),
         logger=logging,  # type: ignore
         model_manager=None,  # type: ignore
-        model_records=None,  # type: ignore
         download_queue=None,  # type: ignore
-        model_install=None,  # type: ignore
         names=None,  # type: ignore
         performance_statistics=InvocationStatsService(),
         processor=DefaultInvocationProcessor(),
diff --git a/tests/backend/model_manager_2/model_loading/test_model_load.py b/tests/backend/model_manager_2/model_loading/test_model_load.py
new file mode 100644
index 00000000000..a7a64e91ac0
--- /dev/null
+++ b/tests/backend/model_manager_2/model_loading/test_model_load.py
@@ 
-0,0 +1,22 @@ +""" +Test model loading +""" + +from pathlib import Path + +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager.load import AnyModelLoader +from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 + + +def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path): + store = mm2_installer.record_store + matches = store.search_by_attr(model_name="test_embedding") + assert len(matches) == 0 + key = mm2_installer.register_path(embedding_file) + loaded_model = mm2_loader.load_model(store.get_model(key)) + assert loaded_model is not None + assert loaded_model.config.key == key + with loaded_model as model: + assert isinstance(model, TextualInversionModelRaw) diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d6d091befea..d85eab67dd3 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -20,6 +20,7 @@ ModelFormat, ModelType, ) +from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( @@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: return app_config +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader: + logger = InvokeAILogger.get_logger(config=mm2_app_config) + ram_cache = ModelCache( + logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size + ) + convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) + return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache) + + @pytest.fixture def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config) From 13a9ea35b5884409f7ca5dedddaa5091daf0eaf7 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 11 Feb 2024 23:37:49 -0500 Subject: [PATCH 083/100] add back the `heuristic_import()` method and extend repo_ids to arbitrary file paths --- docs/contributing/MODEL_MANAGER.md | 52 ++++++++++++-- invokeai/app/api/routers/model_manager_v2.py | 70 ++++++++++++++++++- invokeai/app/api_app.py | 1 - .../model_install/model_install_base.py | 39 ++++++++++- .../model_install/model_install_default.py | 43 +++++++++++- .../model_manager/util/select_hf_files.py | 6 ++ 6 files changed, 199 insertions(+), 12 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 39220f4ba89..959b7f9733c 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -446,6 +446,44 @@ required parameters: Once initialized, the installer will provide the following methods: +#### install_job = installer.heuristic_import(source, [config], [access_token]) + +This is a simplified interface to the installer which takes a source +string, an optional model configuration dictionary and an optional +access token. + +The `source` is a string that can be any of these forms + +1. 
A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
+2. A URL pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
+3. A HuggingFace repo_id with any of the following formats:
+   - `model/name` -- entire model
+   - `model/name:fp32` -- entire model, using the fp32 variant
+   - `model/name:fp16:vae` -- vae submodel, using the fp16 variant
+   - `model/name::vae` -- vae submodel, using default precision
+   - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
+   - `model/name::path/to/model.safetensors` -- an individual model file, default variant
+
+Note that by specifying a relative path to the top of the HuggingFace
+repo, you can download and install arbitrary model files.
+
+The variant, if not provided, will be automatically filled in with
+`fp32` if the user has requested full precision, and `fp16`
+otherwise. If a variant that does not exist is requested, then the
+method will install whatever HuggingFace returns as its default
+revision.
+
+`config` is an optional dict of values that will override the
+autoprobed values for model type, base, scheduler prediction type, and
+so forth. See [Model configuration and
+probing](#Model-configuration-and-probing) for details.
+
+`access_token` is an optional access token for accessing resources
+that need authentication.
+
+The method will return a `ModelInstallJob`. This object is discussed
+at length in the following section.
+
 #### install_job = installer.import_model()
 
 The `import_model()` method is the core of the installer. The
@@ -464,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif
 source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5')  # a repo_id
 source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae')  # a subfolder within a repo_id
 source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16')  # a named variant of a HF model
+source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file
 
-source6 = URLModelSource(url='https://civitai.com/api/download/models/63006')  # model located at a URL
-source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
+source7 = URLModelSource(url='https://civitai.com/api/download/models/63006')  # model located at a URL
+source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
 
-for source in [source1, source2, source3, source4, source5, source6, source7]:
+for source in [source1, source2, source3, source4, source5, source6, source7, source8]:
     install_job = installer.install_model(source)
 
@@ -522,7 +561,6 @@ can be passed to `import_model()`.
 attributes returned by the model prober. See the section below for
 details.
 
-
 #### LocalModelSource
 
 This is used for a model that is located on a locally-accessible Posix
@@ -715,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return
 True if the job is in the complete, errored or cancelled states.
 
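+A minimal polling sketch (assuming the `installer` and `source1`
+objects defined above, plus `import time`; not taken verbatim from the
+codebase):
+
+```
+import time
+
+job = installer.install_model(source1)
+while not job.in_terminal_state:  # completed, errored or cancelled
+    time.sleep(1)
+print(f"{job.source}: {job.status}")
+```
+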
-#### Model confguration and probing
+#### Model configuration and probing
 
 The install service uses the `invokeai.backend.model_manager.probe`
 module during import to determine the model's type, base type, and
@@ -1106,7 +1144,7 @@ job = queue.create_download_job(
     event_handlers=[my_handler1, my_handler2],  # if desired
     start=True,
 )
-    ```
+```
 
 The `filename` argument forces the downloader to use the specified name
 for the file rather than the name provided by the remote source,
@@ -1427,9 +1465,9 @@ set of keys to the corresponding model config objects.
 Find all model metadata records that have the given author and return
 a set of keys to the corresponding model config objects.
 
-# The remainder of this documentation is provisional, pending implementation of the Load service
+***
 
-## Let's get loaded, the lowdown on ModelLoadService
+## The Lowdown on the ModelLoadService
 
 The `ModelLoadService` is responsible for loading a named model into
 memory so that it can be used for inference. Despite the fact that it
diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py
index 4fc785e4f7a..4482edfa0f6 100644
--- a/invokeai/app/api/routers/model_manager_v2.py
+++ b/invokeai/app/api/routers/model_manager_v2.py
@@ -251,9 +251,75 @@ async def add_model_record(
     return result
 
 
+@model_manager_v2_router.post(
+    "/heuristic_import",
+    operation_id="heuristic_import_model",
+    responses={
+        201: {"description": "The model imported successfully"},
+        415: {"description": "Unrecognized file/folder format"},
+        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+)
+async def heuristic_import(
+    source: str,
+    config: Optional[Dict[str, Any]] = Body(
+        description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
+        default=None,
+    ),
+    access_token: Optional[str] = None,
+) -> ModelInstallJob:
+    """Install a model using a string identifier.
+
+    `source` can be any of the following.
+
+    1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
+    2. A URL pointing to a single downloadable model file
+    3. A HuggingFace repo_id with any of the following formats:
+       - model/name
+       - model/name:fp16:vae
+       - model/name::vae -- use default precision
+       - model/name:fp16:path/to/model.safetensors
+       - model/name::path/to/model.safetensors
+
+    `config` is an optional dict containing model configuration values that will override
+    the ones that are probed automatically.
+
+    `access_token` is an optional access token for use with URLs that require
+    authentication.
+
+    Models will be downloaded, probed, configured and installed in a
+    series of background threads. The return object has a `status` attribute
+    that can be used to monitor progress.
+
+    See the documentation for `import_model` for more information on
+    interpreting the job information returned by this route. 
+ """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + result: ModelInstallJob = installer.heuristic_import( + source=source, + config=config, + ) + logger.info(f"Started installation of {source}") + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return result + + @model_manager_v2_router.post( "/import", - operation_id="import_model_record", + operation_id="import_model", responses={ 201: {"description": "The model imported successfully"}, 415: {"description": "Unrecognized file/folder format"}, @@ -269,7 +335,7 @@ async def import_model( default=None, ), ) -> ModelInstallJob: - """Add a model using its local path, repo_id, or remote URL. + """Install a model using its local path, repo_id, or remote URL. Models will be downloaded, probed, configured and installed in a series of background threads. The return object has `status` attribute diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 851cbc8160e..1831b54c13c 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -49,7 +49,6 @@ download_queue, images, model_manager_v2, - models, session_queue, sessions, utilities, diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 635cb154d64..943cdf1157f 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -127,8 +127,8 @@ def proper_repo_id(cls, v: str) -> str: # noqa D102 def __str__(self) -> str: """Return string version of repoid when string rep needed.""" base: str = self.repo_id + base += f":{self.variant or ''}" base += f":{self.subfolder}" if self.subfolder else "" - base += f" ({self.variant})" if self.variant else "" return base @@ -324,6 +324,43 @@ def install_path( :returns id: The string ID of the registered model. """ + @abstractmethod + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + r"""Install the indicated model using heuristics to interpret user intentions. + + :param source: String source + :param config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record as described in `import_model()`. + :param access_token: Optional access token for remote sources. + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. 
+ + The previous support for recursing into a local folder and loading all model-like files + has been removed. + """ + pass + @abstractmethod def import_model( self, diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index d32af4a513d..df73fcb8cbe 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -50,6 +50,7 @@ ModelInstallJob, ModelInstallServiceBase, ModelSource, + StringLikeSource, URLModelSource, ) @@ -177,6 +178,34 @@ def install_path( info, ) + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=match.group(2) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + access_token=access_token, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=AnyHttpUrl(source), + access_token=access_token, + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return self.import_model(source_obj, config) + def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] if similar_jobs: @@ -571,6 +600,8 @@ def _import_remote_model( # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # Currently the tmpdir isn't automatically removed at exit because it is # being held in a daemon thread. + if len(remote_files) == 0: + raise ValueError(f"{source}: No downloadable files found") tmpdir = Path( mkdtemp( dir=self._app_config.models_path, @@ -586,6 +617,16 @@ def _import_remote_model( bytes=0, total_bytes=0, ) + # In the event that there is a subfolder specified in the source, + # we need to remove it from the destination path in order to avoid + # creating unwanted subfolders + if hasattr(source, "subfolder") and source.subfolder: + root = Path(remote_files[0].path.parts[0]) + subfolder = root / source.subfolder + else: + root = Path(".") + subfolder = Path(".") + # we remember the path up to the top of the tmpdir so that it may be # removed safely at the end of the install process. 
install_job._install_tmpdir = tmpdir @@ -595,7 +636,7 @@ def _import_remote_model( self._logger.debug(f"remote_files={remote_files}") for model_file in remote_files: url = model_file.url - path = model_file.path + path = root / model_file.path.relative_to(subfolder) self._logger.info(f"Downloading {url} => {path}") install_job.total_bytes += model_file.size assert hasattr(source, "access_token") diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index a894d915de6..2fd7a3721ab 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -36,6 +36,11 @@ def filter_files( """ variant = variant or ModelRepoVariant.DEFAULT paths: List[Path] = [] + root = files[0].parts[0] + + # if the subfolder is a single file, then bypass the selection and just return it + if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]: + return [root / subfolder] # Start by filtering on model file extensions, discarding images, docs, etc for file in files: @@ -61,6 +66,7 @@ def filter_files( # limit search to subfolder if requested if subfolder: + subfolder = root / subfolder paths = [x for x in paths if x.parent == Path(subfolder)] # _filter_by_variant uniquifies the paths and returns a set From 46c8ce9fed21698e97221b7c3bebecadc15117cd Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 12 Feb 2024 14:27:17 -0500 Subject: [PATCH 084/100] add a JIT download_and_cache() call to the model installer --- docs/contributing/MODEL_MANAGER.md | 40 +++++++++++++++ .../app/services/download/download_base.py | 13 +++++ .../app/services/download/download_default.py | 15 +++++- .../model_install/model_install_base.py | 34 ++++++++++++- .../model_install/model_install_default.py | 49 ++++++++++++++++++- .../convert_cache/convert_cache_default.py | 8 ++- 6 files changed, 154 insertions(+), 5 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 959b7f9733c..b711c654de8 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -792,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will return from the call if jobs aren't completed in the specified time. An argument of 0 (the default) will block indefinitely. +#### jobs = installer.wait_for_job(job, [timeout]) + +Like `wait_for_installs()`, but block until a specific job has +completed or errored, and then return the job. The optional `timeout` +argument will return from the call if the job doesn't complete in the +specified time. An argument of 0 (the default) will block +indefinitely. + #### jobs = installer.list_jobs() Return a list of all active and complete `ModelInstallJobs`. @@ -854,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally deletes the corresponding model weights file(s), regardless of whether they are inside or outside the InvokeAI models hierarchy. + +#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) + +This utility routine will download the model file located at source, +cache it, and return the path to the cached file. It does not attempt +to determine the model type, probe its configuration values, or +register it with the models database. + +You may provide an access token if the remote source requires +authorization. 
The call will block until the file is completely
+downloaded, the download is cancelled, or an error is raised. If
+you provide a timeout (in seconds), the call will raise a
+`TimeoutError` exception if the download hasn't completed in the
+specified period.
+
+You may use this mechanism to request any type of file, not just a
+model. The file will be stored in a subdirectory of
+`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the
+cache, its path will be returned without redownloading it.
+
+Be aware that the models cache is cleared of infrequently-used files
+and directories at regular intervals when the size of the cache
+exceeds the value specified in Invoke's `convert_cache` configuration
+variable.
+
 #### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
 
 This method will recursively scan the directory indicated in
@@ -1187,6 +1220,13 @@ queue or was not created by this queue.
 This method will block until all the active jobs in the queue have
 reached a terminal state (completed, errored or cancelled).
 
+#### queue.wait_for_job(job, [timeout])
+
+This method will block until the indicated job has reached a terminal
+state (completed, errored or cancelled). If the optional timeout is
+provided, the call will block for at most timeout seconds, raising a
+TimeoutError if the job has not completed by then.
+
 #### jobs = queue.list_jobs()
 
 This will return a list of all jobs, including ones that have not yet
diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py
index f854f64f585..2ac13b825fe 100644
--- a/invokeai/app/services/download/download_base.py
+++ b/invokeai/app/services/download/download_base.py
@@ -260,3 +260,16 @@ def cancel_job(self, job: DownloadJob) -> None:
     def join(self) -> None:
         """Wait until all jobs are off the queue."""
         pass
+
+    @abstractmethod
+    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+        """Wait until the indicated download job has reached a terminal state.
+
+        This will block until the indicated download job has completed,
+        been cancelled, or errored out.
+
+        :param job: The job to wait on.
+        :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
+            the job hasn't completed within the indicated time. 
+ """ + pass diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7613c0893fc..f740c500873 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -4,6 +4,7 @@ import os import re import threading +import time import traceback from pathlib import Path from queue import Empty, PriorityQueue @@ -52,6 +53,7 @@ def __init__( self._next_job_id = 0 self._queue = PriorityQueue() self._stop_event = threading.Event() + self._job_completed_event = threading.Event() self._worker_pool = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") @@ -188,6 +190,16 @@ def cancel_all_jobs(self) -> None: if not job.in_terminal_state: self.cancel_job(job) + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._job_completed_event.wait(timeout=5): # in case we miss an event + self._job_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + def _start_workers(self, max_workers: int) -> None: """Start the requested number of worker threads.""" self._stop_event.clear() @@ -223,6 +235,7 @@ def _download_next_item(self) -> None: finally: job.job_ended = get_iso_timestamp() + self._job_completed_event.set() # signal a change to terminal state self._queue.task_done() self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") @@ -407,7 +420,7 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: # Example on_progress event handler to display a TQDM status bar # Activate with: -# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update +# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update)) class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 943cdf1157f..39ea8c4a0d1 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -422,6 +422,18 @@ def prune_jobs(self) -> None: def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" + @abstractmethod + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Wait for the indicated job to reach a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + @abstractmethod def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: """ @@ -431,7 +443,8 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: completed, been cancelled, or errored out. :param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if - installs do not complete within the indicated time. + installs do not complete within the indicated time. A timeout of zero (the default) + will block indefinitely until the installs complete. 
""" @abstractmethod @@ -447,3 +460,22 @@ def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: @abstractmethod def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" + + @abstractmethod + def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + """ + Download the model file located at source to the models cache and return its Path. + + :param source: A Url or a string that can be converted into one. + :param access_token: Optional access token to access restricted resources. + + The model file will be downloaded into the system-wide model cache + (`models/.cache`) if it isn't already there. Note that the model cache + is periodically cleared of infrequently-used entries when the model + converter runs. + + Note that this doesn't automaticallly install or register the model, but is + intended for use by nodes that need access to models that aren't directly + supported by InvokeAI. The downloading process takes advantage of the download queue + to avoid interrupting other operations. + """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index df73fcb8cbe..414e3007157 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,7 +17,7 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL @@ -87,6 +87,7 @@ def __init__( self._lock = threading.Lock() self._stop_event = threading.Event() self._downloads_changed_event = threading.Event() + self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._running = False @@ -241,6 +242,17 @@ def get_job_by_id(self, id: int) -> ModelInstallJob: # noqa D102 assert isinstance(jobs[0], ModelInstallJob) return jobs[0] + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._install_completed_event.wait(timeout=5): # in case we miss an event + self._install_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + + # TODO: Better name? Maybe wait_for_jobs()? 
Maybe too easily confused with above def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" start = time.time() @@ -248,7 +260,7 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa if self._downloads_changed_event.wait(timeout=5): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: - raise Exception("Timeout exceeded") + raise TimeoutError("Timeout exceeded") self._install_queue.join() return self._install_jobs @@ -302,6 +314,38 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 path.unlink() self.unregister(key) + def download_and_cache( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: int = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path.""" + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + model_path = self._app_config.models_convert_cache_path / model_hash + + # We expect the cache directory to contain one and only one downloaded file. + # We don't know the file's name in advance, as it is set by the download + # content-disposition header. + if model_path.exists(): + contents = [x for x in model_path.iterdir() if x.is_file()] + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + job = self._download_queue.download( + source=AnyHttpUrl(str(source)), + dest=model_path, + access_token=access_token, + on_progress=TqdmProgress().update, + ) + self._download_queue.wait_for_job(job, timeout) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -365,6 +409,7 @@ def _install_next_item(self) -> None: # if this is an install of a remote file, then clean up the temporary directory if job._install_tmpdir is not None: rmtree(job._install_tmpdir) + self._install_completed_event.set() self._install_queue.task_done() self._logger.info("Install thread exiting") diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index 4c361258d90..84f4f76299a 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -53,7 +53,13 @@ def by_atime(path: Path) -> float: sentinel = path / config if sentinel.exists(): return sentinel.stat().st_atime - return 0.0 + + # no sentinel file found! - pick the most recent file in the directory + try: + atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True) + return atimes[0] + except IndexError: + return 0.0 # sort by last access time - least accessed files will be at the end lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) From 0845a0ed847f156646e8647d1cc0dfc57d55055b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 12 Feb 2024 21:25:42 -0500 Subject: [PATCH 085/100] add route for model conversion from safetensors to diffusers - Begin to add SwaggerUI documentation for AnyModelConfig and other discriminated Unions. 
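
A hypothetical client-side sketch of the new route (the host, port, and
`/api` prefix are assumptions; only the `/v2/models/convert/{key}` path
comes from this patch):

```
import requests

key = "3a0e45ff858926fd4a63da630688b1e1"  # key of a safetensors main model
resp = requests.put(f"http://localhost:9090/api/v2/models/convert/{key}")
resp.raise_for_status()  # 400/404/409 on failure, per the route's responses
new_config = resp.json()  # config record for the converted model; key and hash change
```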
--- invokeai/app/api/routers/model_manager_v2.py | 80 ++++++++++++++++++- .../model_install/model_install_default.py | 6 +- .../services/model_load/model_load_base.py | 14 +++- .../services/model_load/model_load_default.py | 12 ++- .../model_records/model_records_base.py | 7 -- .../model_records/model_records_sql.py | 10 +-- .../backend/model_manager/load/load_base.py | 5 ++ 7 files changed, 113 insertions(+), 21 deletions(-) diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 4482edfa0f6..8d31c6f286b 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -2,6 +2,7 @@ """FastAPI route for model configuration records.""" import pathlib +import shutil from hashlib import sha1 from random import randbytes from typing import Any, Dict, List, Optional, Set @@ -24,8 +25,10 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, + MainCheckpointConfig, ModelFormat, ModelType, + SubModelType, ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata @@ -318,7 +321,7 @@ async def heuristic_import( @model_manager_v2_router.post( - "/import", + "/install", operation_id="import_model", responses={ 201: {"description": "The model imported successfully"}, @@ -490,6 +493,81 @@ async def sync_models_to_config() -> Response: return Response(status_code=204) +@model_manager_v2_router.put( + "/convert/{key}", + operation_id="convert_model", + responses={ + 200: {"description": "Model converted successfully"}, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, +) +async def convert_model( + key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), +) -> AnyModelConfig: + """ + Permanently convert a model into diffusers format, replacing the safetensors version. + Note that the key and model hash will change. Use the model configuration record returned + by this call to get the new values. 
+ """ + logger = ApiDependencies.invoker.services.logger + loader = ApiDependencies.invoker.services.model_manager.load + store = ApiDependencies.invoker.services.model_manager.store + installer = ApiDependencies.invoker.services.model_manager.install + + try: + model_config = store.get_model(key) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + + if not isinstance(model_config, MainCheckpointConfig): + logger.error(f"The model with key {key} is not a main checkpoint model.") + raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") + + # loading the model will convert it into a cached diffusers file + loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) + + # Get the path of the converted model from the loader + cache_path = loader.convert_cache.cache_path(key) + assert cache_path.exists() + + # temporarily rename the original safetensors file so that there is no naming conflict + original_name = model_config.name + model_config.name = f"{original_name}.DELETE" + store.update_model(key, config=model_config) + + # install the diffusers + try: + new_key = installer.install_path( + cache_path, + config={ + "name": original_name, + "description": model_config.description, + "original_hash": model_config.original_hash, + "source": model_config.source, + }, + ) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + + # get the original metadata + if orig_metadata := store.get_metadata(key): + store.metadata_store.add_metadata(new_key, orig_metadata) + + # delete the original safetensors file + installer.delete(key) + + # delete the cached version + shutil.rmtree(cache_path) + + # return the config record for the new diffusers directory + new_config: AnyModelConfig = store.get_model(new_key) + return new_config + + @model_manager_v2_router.put( "/merge", operation_id="merge", diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 414e3007157..20a85a82a14 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -162,8 +162,10 @@ def install_path( config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) - old_hash = info.original_hash - dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name + old_hash = info.current_hash + dest_path = ( + self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name) + ) try: new_path = self._copy_model(model_path, dest_path) except FileExistsError as excp: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index f298d98ce6d..45eaf4652fb 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -5,8 +5,10 @@ from typing import Optional from invokeai.app.invocations.baseinvocation import InvocationContext -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel +from invokeai.backend.model_manager.load.convert_cache import 
ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase class ModelLoadServiceBase(ABC): @@ -70,3 +72,13 @@ def load_model_by_attr( NotImplementedException -- a model loader was not provided at initialization time ValueError -- more than one model matches this combination """ + + @property + @abstractmethod + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + + @property + @abstractmethod + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 67107cada6e..a6ccd5afbc3 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -10,7 +10,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.load.model_cache import ModelCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.util.logging import InvokeAILogger from .model_load_base import ModelLoadServiceBase @@ -46,6 +46,16 @@ def __init__( ), ) + @property + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + return self._any_loader.ram_cache + + @property + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" + return self._any_loader.convert_cache + def load_model_by_key( self, key: str, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index e2e98c7e896..b2eacc524b7 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -17,7 +17,6 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -195,12 +194,6 @@ def search_by_attr( """ pass - @property - @abstractmethod - def loader(self) -> Optional[AnyModelLoader]: - """Return the model loader used by this instance.""" - pass - def all_models(self) -> List[AnyModelConfig]: """Return all the model configs in the database.""" return self.search_by_attr() diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index f48175351de..84a14123838 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,7 +54,6 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from ..shared.sqlite.sqlite_database import SqliteDatabase @@ -70,28 +69,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None): + def __init__(self, db: 
SqliteDatabase):
         """
         Initialize a new object from preexisting sqlite3 connection and threading lock objects.
 
         :param db: Sqlite connection object
-        :param loader: Initialized model loader object (optional)
         """
         super().__init__()
         self._db = db
         self._cursor = db.conn.cursor()
-        self._loader = loader
 
     @property
     def db(self) -> SqliteDatabase:
         """Return the underlying database."""
         return self._db
 
-    @property
-    def loader(self) -> Optional[AnyModelLoader]:
-        """Return the model loader used by this instance."""
-        return self._loader
-
     def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
         """
         Add a model to the database.
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index 3d026af2269..5f392ada75e 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -117,6 +117,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]:
         """Return the RAM cache associated used by the loaders."""
         return self._ram_cache
 
+    @property
+    def convert_cache(self) -> ModelConvertCacheBase:
+        """Return the convert cache used by the loaders."""
+        return self._convert_cache
+
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Return a model given its configuration.
 
From 631f6cae1992be22e69805a3962eb48ab19d2b6d Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 13 Feb 2024 00:26:49 -0500
Subject: [PATCH 086/100] fix a number of typechecking errors

---
 invokeai/app/api/routers/download_queue.py    |  6 +-
 invokeai/app/api/routers/model_manager_v2.py  | 69 ++++++++++++++++---
 invokeai/app/invocations/ip_adapter.py        |  4 +-
 invokeai/app/invocations/model.py             |  8 +--
 invokeai/app/services/config/config_base.py   | 11 +--
 invokeai/app/services/config/config_common.py |  2 +-
 .../app/services/download/download_default.py | 12 ++--
 invokeai/app/util/misc.py                     |  8 +--
 .../backend/model_manager/load/load_base.py   | 13 ++--
 .../load/model_cache/model_cache_default.py   |  4 +-
 .../metadata/fetch/fetch_base.py              |  4 +-
 invokeai/backend/model_manager/probe.py       |  2 +-
 invokeai/backend/model_manager/search.py      |  6 +-
 13 files changed, 101 insertions(+), 48 deletions(-)

diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py
index 2dba376c181..a6e53c7a5c4 100644
--- a/invokeai/app/api/routers/download_queue.py
+++ b/invokeai/app/api/routers/download_queue.py
@@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
         400: {"description": "Bad request"},
     },
 )
-async def prune_downloads():
+async def prune_downloads() -> Response:
     """Prune completed and errored jobs."""
     queue = ApiDependencies.invoker.services.download_queue
     queue.prune_jobs()
@@ -87,7 +87,7 @@ async def get_download_job(
 )
 async def cancel_download_job(
     id: int = Path(description="ID of the download job to cancel."),
-):
+) -> Response:
     """Cancel a download job using its ID."""
     try:
         queue = ApiDependencies.invoker.services.download_queue
@@ -105,7 +105,7 @@ async def cancel_download_job(
         204: {"description": "Download jobs have been cancelled"},
     },
 )
-async def cancel_all_download_jobs():
+async def cancel_all_download_jobs() -> Response:
     """Cancel all download jobs."""
     ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
     return Response(status_code=204)
diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py
index 8d31c6f286b..029c6207072 
100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -9,7 +9,7 @@ from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from starlette.exceptions import HTTPException from typing_extensions import Annotated @@ -37,6 +37,35 @@ model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) +example_model_output = { + "path": "sd-1/main/openjourney", + "name": "openjourney", + "base": "sd-1", + "type": "main", + "format": "diffusers", + "key": "3a0e45ff858926fd4a63da630688b1e1", + "original_hash": "1c12f18fb6e403baef26fb9d720fbd2f", + "current_hash": "1c12f18fb6e403baef26fb9d720fbd2f", + "description": "sd-1 main model openjourney", + "source": "/opt/invokeai/models/sd-1/main/openjourney", + "last_modified": 1707794711, + "vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors", + "variant": "normal", + "prediction_type": "epsilon", + "repo_variant": "fp16", +} + +example_model_input = { + "path": "base/type/name", + "name": "model_name", + "base": "sd-1", + "type": "main", + "format": "diffusers", + "description": "Model description", + "vae": None, + "variant": "normal", +} + class ModelsList(BaseModel): """Return list of configs.""" @@ -88,7 +117,10 @@ async def list_model_records( "/i/{key}", operation_id="get_model_record", responses={ - 200: {"description": "Success"}, + 200: { + "description": "The model configuration was retrieved successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, }, @@ -165,18 +197,22 @@ async def search_by_metadata_tags( "/i/{key}", operation_id="update_model_record", responses={ - 200: {"description": "The model was updated successfully"}, + 200: { + "description": "The model was updated successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, 409: {"description": "There is already a model corresponding to the new name"}, }, status_code=200, - response_model=AnyModelConfig, ) async def update_model_record( key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: + info: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> Annotated[AnyModelConfig, Field(example="this is neat")]: """Update model contents with a new config. 
If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
@@ -225,7 +261,10 @@ async def del_model_record(
     "/i/",
     operation_id="add_model_record",
     responses={
-        201: {"description": "The model added successfully"},
+        201: {
+            "description": "The model added successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
         409: {"description": "There is already a model corresponding to this path or repo_id"},
         415: {"description": "Unrecognized file/folder format"},
     },
@@ -270,6 +309,7 @@ async def heuristic_import(
     config: Optional[Dict[str, Any]] = Body(
         description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
         default=None,
+        example={"name": "modelT", "description": "antique cars"},
     ),
     access_token: Optional[str] = None,
 ) -> ModelInstallJob:
@@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response:
     "/convert/{key}",
     operation_id="convert_model",
     responses={
-        200: {"description": "Model converted successfully"},
+        200: {
+            "description": "Model converted successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
         400: {"description": "Bad request"},
         404: {"description": "Model not found"},
         409: {"description": "There is already a model registered at this location"},
@@ -571,6 +614,15 @@ async def convert_model(
 @model_manager_v2_router.put(
     "/merge",
     operation_id="merge",
+    responses={
+        200: {
+            "description": "Models merged successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
+        409: {"description": "There is already a model registered at this location"},
+    },
 )
 async def merge(
     keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
@@ -596,7 +648,6 @@ async def merge(
         interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
         merge_dest_directory: Specify a directory to store the merged model in [models directory]
     """
-    print(f"here i am, keys={keys}")
     logger = ApiDependencies.invoker.services.logger
     try:
         logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}")
diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index f64b3266bbb..01124f62f3c 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -90,10 +90,10 @@ def validate_begin_end_step_percent(self) -> Self:
 
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) + ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_models = context.services.model_records.search_by_attr( + image_encoder_models = context.services.model_manager.store.search_by_attr( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index fa6e8b98da0..f78425c6ee3 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -103,7 +103,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(key): + if not context.services.model_manager.store.exists(key): raise Exception(f"Unknown model {key}") return ModelLoaderOutput( @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_records.exists(lora_key): + if not context.services.model_manager.store.exists(lora_key): raise Exception(f"Unkown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -252,7 +252,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_records.exists(lora_key): + if not context.services.model_manager.store.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> VAEOutput: key = self.vae_model.key - if not context.services.model_records.exists(key): + if not context.services.model_manager.store.exists(key): raise Exception(f"Unkown vae: {key}!") return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index a304b38a955..983df6b4684 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" initconf: ClassVar[Optional[DictConfig]] = None - argparse_groups: ClassVar[Dict] = {} + argparse_groups: ClassVar[Dict[str, Any]] = {} model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) - def parse_args(self, argv: Optional[list] = sys.argv[1:]): + def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None: """Call to parse command-line arguments.""" parser = self.get_parser() opt, unknown_opts = parser.parse_known_args(argv) @@ -68,7 +68,7 @@ def to_yaml(self) -> str: return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser): + def add_parser_arguments(cls, parser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] @@ -117,7 +117,8 @@ def cmd_name(cls, command_field: str = "type") -> str: """Return the category of a setting.""" 
hints = get_type_hints(cls) if command_field in hints: - return get_args(hints[command_field])[0] + result: str = get_args(hints[command_field])[0] + return result else: return "Uncategorized" @@ -158,7 +159,7 @@ def _excluded_from_yaml(cls) -> List[str]: ] @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override=None): + def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None: """Add the argparse arguments for a setting parser.""" field_type = get_type_hints(cls).get(name) default = ( diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index d11bcabcf9c..27a0f859c23 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser): It also supports reading defaults from an init file. """ - def print_help(self, file=None): + def print_help(self, file=None) -> None: text = self.format_help() pydoc.pager(text) diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index f740c500873..7008f8ed741 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,12 +8,12 @@ import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError -from tqdm import tqdm +from tqdm import tqdm, std from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp @@ -49,12 +49,12 @@ def __init__( :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. 
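The typed attributes introduced in the hunk below describe a classic thread-pooled queue: jobs go into a priority queue and a small set of worker threads drains it until a stop event fires. A minimal, self-contained sketch of that wiring follows; the Job class and the pool size of 5 are illustrative stand-ins, not InvokeAI's real DownloadJob or configuration.

```python
import threading
from dataclasses import dataclass, field
from queue import PriorityQueue

@dataclass(order=True)
class Job:
    """Illustrative stand-in for DownloadJob; lower priority runs first."""
    priority: int
    url: str = field(compare=False)  # excluded from ordering comparisons

jobs: "PriorityQueue[Job]" = PriorityQueue()
stop_event = threading.Event()

def worker() -> None:
    # Drain jobs until the stop event is set; get() blocks while the queue is empty.
    while not stop_event.is_set():
        job = jobs.get()
        print(f"downloading {job.url}")
        jobs.task_done()

worker_pool = {threading.Thread(target=worker, daemon=True) for _ in range(5)}
for thread in worker_pool:
    thread.start()

jobs.put(Job(priority=10, url="https://example.com/model.safetensors"))
jobs.join()  # block until every queued job has been marked done
```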
""" - self._jobs = {} + self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 - self._queue = PriorityQueue() + self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() self._job_completed_event = threading.Event() - self._worker_pool = set() + self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") self._event_bus = event_bus @@ -424,7 +424,7 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" - _bars: Dict[int, tqdm] # the tqdm object + _bars: Dict[int, tqdm] # type: ignore _last: Dict[int, int] # last bytes downloaded def __init__(self) -> None: # noqa D107 diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 910b05d8dde..da431929dbe 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -5,7 +5,7 @@ import numpy as np -def get_timestamp(): +def get_timestamp() -> int: return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) @@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: SEED_MAX = np.iinfo(np.uint32).max -def get_random_seed(): +def get_random_seed() -> int: rng = np.random.default_rng(seed=None) return int(rng.integers(0, SEED_MAX)) -def uuid_string(): +def uuid_string() -> str: res = uuid.uuid4() return str(res) -def is_optional(value: typing.Any): +def is_optional(value: typing.Any) -> bool: """Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None].""" return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 5f392ada75e..7649dee762b 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -22,6 +22,7 @@ AnyModel, AnyModelConfig, BaseModelType, + ModelConfigBase, ModelFormat, ModelType, SubModelType, @@ -70,7 +71,7 @@ def __init__( pass @abstractmethod - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its confguration. @@ -122,7 +123,7 @@ def convert_cache(self) -> ModelConvertCacheBase: """Return the convert cache associated used by the loaders.""" return self._convert_cache - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its configuration. 
@@ -144,8 +145,8 @@ def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) @classmethod def get_implementation( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]: + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: """Get subclass of ModelLoaderBase registered to handle base and type.""" # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) @@ -161,8 +162,8 @@ def get_implementation( @classmethod def _handle_subtype_overrides( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[AnyModelConfig, Optional[SubModelType]]: + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: model_path = Path(config.vae) config_class = ( diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 98d6f34cead..786396062cf 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -34,8 +34,8 @@ from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase -from .model_locker import ModelLocker, ModelLockerBase +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase +from .model_locker import ModelLocker if choose_torch_device() == torch.device("mps"): from torch import mps diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index d628ab5c178..5d75493b92f 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -20,7 +20,7 @@ from invokeai.backend.model_manager import ModelRepoVariant -from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata class ModelMetadataFetchBase(ABC): @@ -62,5 +62,5 @@ def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyMod @classmethod def from_json(cls, json: str) -> AnyModelRepoMetadata: """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" - metadata = AnyModelRepoMetadataValidator.validate_json(json) + metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore return metadata diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index e7d21c578fd..2c2066d7c52 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -166,7 +166,7 @@ def probe( fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash - if format_type == ModelFormat.Diffusers: + if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): fields["repo_variant"] = 
fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index f7e1e1bed76..0ead22b743f 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Set[Path] = Field(default=None) - scanned_dirs: Set[Path] = Field(default=None) - pruned_paths: Set[Path] = Field(default=None) + models_found: Optional[Set[Path]] = Field(default=None) + scanned_dirs: Optional[Set[Path]] = Field(default=None) + pruned_paths: Optional[Set[Path]] = Field(default=None) def search_started(self) -> None: self.models_found = set() From 3e82f63c7eae6f5df2ea5ee63e0f4e0c93a3cfb1 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 14 Feb 2024 11:10:50 -0500 Subject: [PATCH 087/100] improve swagger documentation --- invokeai/app/api/routers/model_manager_v2.py | 214 ++++++++++++------ .../app/services/download/download_default.py | 2 +- invokeai/backend/model_manager/config.py | 16 +- 3 files changed, 159 insertions(+), 73 deletions(-) diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 029c6207072..2471e0d8c9b 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -9,7 +9,7 @@ from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict from starlette.exceptions import HTTPException from typing_extensions import Annotated @@ -37,51 +37,102 @@ model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) -example_model_output = { - "path": "sd-1/main/openjourney", - "name": "openjourney", + +class ModelsList(BaseModel): + """Return list of configs.""" + + models: List[AnyModelConfig] + + model_config = ConfigDict(use_enum_values=True) + + +class ModelTagSet(BaseModel): + """Return tags for a set of models.""" + + key: str + name: str + author: str + tags: Set[str] + + +############################################################################## +# These are example inputs and outputs that are used in places where Swagger +# is unable to generate a correct example. 
+############################################################################## +example_model_config = { + "path": "string", + "name": "string", "base": "sd-1", "type": "main", - "format": "diffusers", - "key": "3a0e45ff858926fd4a63da630688b1e1", - "original_hash": "1c12f18fb6e403baef26fb9d720fbd2f", - "current_hash": "1c12f18fb6e403baef26fb9d720fbd2f", - "description": "sd-1 main model openjourney", - "source": "/opt/invokeai/models/sd-1/main/openjourney", - "last_modified": 1707794711, - "vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors", + "format": "checkpoint", + "config": "string", + "key": "string", + "original_hash": "string", + "current_hash": "string", + "description": "string", + "source": "string", + "last_modified": 0, + "vae": "string", "variant": "normal", "prediction_type": "epsilon", "repo_variant": "fp16", + "upcast_attention": False, + "ztsnr_training": False, } example_model_input = { - "path": "base/type/name", + "path": "/path/to/model", "name": "model_name", "base": "sd-1", "type": "main", - "format": "diffusers", + "format": "checkpoint", + "config": "configs/stable-diffusion/v1-inference.yaml", "description": "Model description", "vae": None, "variant": "normal", } +example_model_metadata = { + "name": "ip_adapter_sd_image_encoder", + "author": "InvokeAI", + "tags": [ + "transformers", + "safetensors", + "clip_vision_model", + "endpoints_compatible", + "region:us", + "has_space", + "license:apache-2.0", + ], + "files": [ + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md", + "path": "ip_adapter_sd_image_encoder/README.md", + "size": 628, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json", + "path": "ip_adapter_sd_image_encoder/config.json", + "size": 560, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors", + "path": "ip_adapter_sd_image_encoder/model.safetensors", + "size": 2528373448, + "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030", + }, + ], + "type": "huggingface", + "id": "InvokeAI/ip_adapter_sd_image_encoder", + "tag_dict": {"license": "apache-2.0"}, + "last_modified": "2023-09-23T17:33:25Z", +} -class ModelsList(BaseModel): - """Return list of configs.""" - - models: List[AnyModelConfig] - - model_config = ConfigDict(use_enum_values=True) - - -class ModelTagSet(BaseModel): - """Return tags for a set of models.""" - - key: str - name: str - author: str - tags: Set[str] +############################################################################## +# ROUTES +############################################################################## @model_manager_v2_router.get( @@ -119,7 +170,7 @@ async def list_model_records( responses={ 200: { "description": "The model configuration was retrieved successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, @@ -137,7 +188,7 @@ async def get_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_manager_v2_router.get("/meta", operation_id="list_model_summary") +@model_manager_v2_router.get("/summary", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, 
description="The number of models per page"), @@ -153,7 +204,10 @@ async def list_model_summary( "/meta/i/{key}", operation_id="get_model_metadata", responses={ - 200: {"description": "Success"}, + 200: { + "description": "The model metadata was retrieved successfully", + "content": {"application/json": {"example": example_model_metadata}}, + }, 400: {"description": "Bad request"}, 404: {"description": "No metadata available"}, }, @@ -199,7 +253,7 @@ async def search_by_metadata_tags( responses={ 200: { "description": "The model was updated successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, @@ -212,7 +266,7 @@ async def update_model_record( info: Annotated[ AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) ], -) -> Annotated[AnyModelConfig, Field(example="this is neat")]: +) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store @@ -263,7 +317,7 @@ async def del_model_record( responses={ 201: { "description": "The model added successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 409: {"description": "There is already a model corresponding to this path or repo_id"}, 415: {"description": "Unrecognized file/folder format"}, @@ -271,7 +325,9 @@ async def del_model_record( status_code=201, ) async def add_model_record( - config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], + config: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], ) -> AnyModelConfig: """Add a model using the configuration information appropriate for its type.""" logger = ApiDependencies.invoker.services.logger @@ -389,32 +445,38 @@ async def import_model( appropriate value: * To install a local path using LocalModelSource, pass a source of form: - `{ + ``` + { "type": "local", "path": "/path/to/model", "inplace": false - }` - The "inplace" flag, if true, will register the model in place in its - current filesystem location. Otherwise, the model will be copied - into the InvokeAI models directory. + } + ``` + The "inplace" flag, if true, will register the model in place in its + current filesystem location. Otherwise, the model will be copied + into the InvokeAI models directory. * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - `{ + ``` + { "type": "hf", "repo_id": "stabilityai/stable-diffusion-2.0", "variant": "fp16", "subfolder": "vae", "access_token": "f5820a918aaf01" - }` - The `variant`, `subfolder` and `access_token` fields are optional. + } + ``` + The `variant`, `subfolder` and `access_token` fields are optional. * To install a remote model using an arbitrary URL, pass: - `{ + ``` + { "type": "url", "url": "http://www.civitai.com/models/123456", "access_token": "f5820a918aaf01" - }` - The `access_token` field is optonal + } + ``` + The `access_token` field is optonal The model's configuration record will be probed and filled in automatically. 
To override the default guesses, pass "metadata" @@ -423,9 +485,9 @@ async def import_model( Installation occurs in the background. Either use list_model_install_jobs() to poll for completion, or listen on the event bus for the following events: - "model_install_running" - "model_install_completed" - "model_install_error" + * "model_install_running" + * "model_install_completed" + * "model_install_error" On successful completion, the event's payload will contain the field "key" containing the installed ID of the model. On an error, the event's payload @@ -459,7 +521,25 @@ async def import_model( operation_id="list_model_install_jobs", ) async def list_model_install_jobs() -> List[ModelInstallJob]: - """Return list of model install jobs.""" + """Return the list of model install jobs. + + Install jobs have a numeric `id`, a `status`, and other fields that provide information on + the nature of the job and its progress. The `status` is one of: + + * "waiting" -- Job is waiting in the queue to run + * "downloading" -- Model file(s) are downloading + * "running" -- Model has downloaded and the model probing and registration process is running + * "completed" -- Installation completed successfully + * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * "cancelled" -- Job was cancelled before completion. + + Once completed, information about the model such as its size, base + model, type, and metadata can be retrieved from the `config_out` + field. For multi-file models such as diffusers, information on individual files + can be retrieved from `download_parts`. + + See the example and schema below for more information. + """ jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() return jobs @@ -473,7 +553,10 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: }, ) async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: - """Return model install job corresponding to the given source.""" + """ + Return the model install job corresponding to the given id. See the documentation for 'List Model Install Jobs' + for information on the format of the return value. + """ try: result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) return result @@ -539,7 +622,7 @@ async def sync_models_to_config() -> Response: responses={ 200: { "description": "Model converted successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "Model not found"}, @@ -551,8 +634,8 @@ async def convert_model( ) -> AnyModelConfig: """ Permanently convert a model into diffusers format, replacing the safetensors version. - Note that the key and model hash will change. Use the model configuration record returned - by this call to get the new values. + Note that during the conversion process the key and model hash will change. + The return value is the model configuration for the converted model.
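The job lifecycle documented above lends itself to a poll-until-terminal loop. A hedged sketch follows, assuming the same host as before and an assumed job-listing path; the field names used ("status", "config_out", "error_type", "error") are the ones named in the docstring.

```python
import time
import requests  # hypothetical client-side sketch, not part of this patch

BASE = "http://localhost:9090/api/v2/models"  # assumed host and route prefix
TERMINAL = {"completed", "error", "cancelled"}  # terminal statuses from the list above

def wait_for_installs(poll_seconds: float = 2.0) -> None:
    while True:
        jobs = requests.get(f"{BASE}/import", timeout=10).json()  # path is assumed
        for job in jobs:
            if job["status"] == "completed":
                # config_out carries the registered model's configuration.
                print("installed:", job["config_out"]["key"])
            elif job["status"] == "error":
                print("failed:", job.get("error_type"), job.get("error"))
        if all(job["status"] in TERMINAL for job in jobs):
            return
        time.sleep(poll_seconds)
```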
""" logger = ApiDependencies.invoker.services.logger loader = ApiDependencies.invoker.services.model_manager.load @@ -617,7 +700,7 @@ async def convert_model( responses={ 200: { "description": "Model converted successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "Model not found"}, @@ -639,14 +722,17 @@ async def merge( ), ) -> AnyModelConfig: """ - Merge diffusers models. - - keys: List of 2-3 model keys to merge together. All models must use the same base type. - merged_model_name: Name for the merged model [Concat model names] - alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] - force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] - interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] - merge_dest_directory: Specify a directory to store the merged model in [models directory] + Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. + ``` + Argument Description [default] + -------- ---------------------- + keys List of 2-3 model keys to merge together. All models must use the same base type. + merged_model_name Name for the merged model [Concat model names] + alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] + force If true, force the merge even if the models were generated by different versions of the diffusers library [False] + interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + merge_dest_directory Specify a directory to store the merged model in [models directory] + ``` """ logger = ApiDependencies.invoker.services.logger try: diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7008f8ed741..6d5cedbcad8 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -13,7 +13,7 @@ import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError -from tqdm import tqdm, std +from tqdm import tqdm from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9f0f774b499..42921f0b32c 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -123,11 +123,11 @@ class ModelRepoVariant(str, Enum): class ModelConfigBase(BaseModel): """Base class for model configuration information.""" - path: str - name: str - base: BaseModelType - type: ModelType - format: ModelFormat + path: str = Field(description="filesystem path to the model file or directory") + name: str = Field(description="model name") + base: BaseModelType = Field(description="base model") + type: ModelType = Field(description="type of the model") + format: ModelFormat = Field(description="model format") key: str = Field(description="unique key for model", default="") original_hash: Optional[str] = Field( description="original fasthash of model contents", default=None @@ -135,9 +135,9 @@ class ModelConfigBase(BaseModel): current_hash: Optional[str] = Field( description="current fasthash of model contents", 
default=None ) # if model is converted or otherwise modified, this will hold updated hash - description: Optional[str] = Field(default=None) - source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) - last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time) + description: Optional[str] = Field(description="human readable description of the model", default=None) + source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, From 86ac55ab5fbe4cc22e9ec5e095dcd239c447a8c8 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 09:36:30 -0500 Subject: [PATCH 088/100] Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager --- docs/contributing/MODEL_MANAGER.md | 2 +- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/model.py | 22 +++++++-------- invokeai/app/invocations/sdxl.py | 28 +++++++++---------- .../backend/model_management/model_manager.py | 2 +- pyproject.toml | 2 +- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b711c654de8..b19699de73d 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1627,7 +1627,7 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 289da2dd73d..c3de5219406 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -812,7 +812,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: ) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.Tensor) + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index f78425c6ee3..71a71a63c83 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -18,7 +18,7 @@ class ModelInfo(BaseModel): key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -110,22 +110,22 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: unet=UNetField( unet=ModelInfo( key=key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -133,7 +133,7 @@ def invoke(self, context: 
InvocationContext) -> ModelLoaderOutput: vae=VaeField( vae=ModelInfo( key=key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -188,7 +188,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -198,7 +198,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -271,7 +271,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -281,7 +281,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -291,7 +291,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip2.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 633a6477fdb..85e6fb787fa 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -43,29 +43,29 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -73,11 +73,11 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -85,7 +85,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -112,29 +112,29 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), 
text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -142,7 +142,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index da74ca3fb58..84d93f15fa8 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -499,7 +499,7 @@ def get_model( model_class=model_class, base_model=base_model, model_type=model_type, - submodel=submodel_type, + submodel_type=submodel_type, ) if model_key not in self.cache_keys: diff --git a/pyproject.toml b/pyproject.toml index 2958e3629a8..f57607bc0af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -245,7 +245,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker", From 0b1c2acd617a0b4f90c8d52f7236cbb7f19becb2 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 09:51:11 -0500 Subject: [PATCH 089/100] References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion --- invokeai/app/invocations/latent.py | 4 ++-- .../load/model_cache/model_cache_default.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c3de5219406..05293fdfee3 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -681,7 +681,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: source_node_id = graph_execution_state.prepared_source_mapping[self.id] # get the unet's config so that we can pass the base to dispatch_progress() - unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump()) + unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) def step_callback(state: PipelineIntermediateState) -> None: self.dispatch_progress(context, source_node_id, state, unet_config.base) @@ -709,7 +709,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: # Apply the LoRA after unet has been moved to its target device for faster patching. 
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): - assert isinstance(unet, torch.Tensor) + assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 786396062cf..02ce1266c75 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -303,18 +303,18 @@ def print_cuda_stats(self) -> None: in_vram_models = 0 locked_in_vram_models = 0 for cache_record in self._cached_models.values(): - assert hasattr(cache_record.model, "device") - if cache_record.model.device == self.storage_device: - in_ram_models += 1 - else: - in_vram_models += 1 - if cache_record.locked: - locked_in_vram_models += 1 + if hasattr(cache_record.model, "device"): + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 + if cache_record.locked: + locked_in_vram_models += 1 - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" - f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" - ) + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" + ) def make_room(self, model_size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" From 8e5139291071608e3c3b33eee42b0f969212f7c3 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 13:07:11 -0500 Subject: [PATCH 090/100] Update _get_hf_load_class to support clipvision models --- invokeai/backend/model_manager/load/load_default.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index df83c8320d9..9ed0ccb2d3c 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -163,8 +163,12 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT else: try: config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config["_class_name"] - return self._hf_definition_to_type(module="diffusers", class_name=class_name) + class_name = config.get("_class_name", None) + if class_name: + return self._hf_definition_to_type(module="diffusers", class_name=class_name) + if config.get("model_type", None) == "clip_vision_model": + class_name = config.get("architectures")[0] + return self._hf_definition_to_type(module="transformers", class_name=class_name) except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e From 20e6d4fa3c50a88642beb058bf676d48e42b7f85 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 13:16:15 -0500 Subject: [PATCH 091/100] Raise InvalidModelConfigException when unable to detect load class in ModelLoader --- invokeai/backend/model_manager/load/load_default.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 9ed0ccb2d3c..1dac121a300 100644 --- 
a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -169,6 +169,8 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT if config.get("model_type", None) == "clip_vision_model": class_name = config.get("architectures")[0] return self._hf_definition_to_type(module="transformers", class_name=class_name) + if not class_name: + raise InvalidModelConfigException("Unable to decipher the load class based on the given config.json") except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e From 4c6bcdbc18ac9e970f73bea0b9516af281a2d020 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:43:41 +1100 Subject: [PATCH 092/100] feat(nodes): update invocation context for mm2, update nodes model usage --- invokeai/app/invocations/compel.py | 40 ++------ invokeai/app/invocations/ip_adapter.py | 7 +- invokeai/app/invocations/latent.py | 71 +++----------- invokeai/app/invocations/model.py | 8 +- invokeai/app/invocations/sdxl.py | 4 +- .../services/model_load/model_load_base.py | 14 +-- .../services/model_load/model_load_default.py | 48 +++++----- .../app/services/shared/invocation_context.py | 94 ++++++++++++++----- invokeai/app/util/step_callback.py | 2 +- 9 files changed, 141 insertions(+), 147 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 3850fb6cc3d..5159d5b89c5 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.load.load_model_by_key( - **self.clip.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.load.load_model_by_key( - **self.clip.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_manager.load.load_model_by_key( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) assert isinstance(lora_info.model, LoRAModelRaw) yield (lora_info.model, lora.weight) del lora_info @@ -94,10 +86,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - loaded_model = context.services.model_manager.load.load_model_by_key( - **self.clip.text_encoder.model_dump(), - context=context, - ).model + loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model assert isinstance(loaded_model, TextualInversionModelRaw) ti_list.append((name, loaded_model)) except UnknownModelException: @@ -165,14 +154,8 @@ def run_clip_compel( lora_prefix: str, zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: - tokenizer_info = context.services.model_manager.load.load_model_by_key( - **clip_field.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.load.load_model_by_key( - **clip_field.text_encoder.model_dump(), -
context=context, - ) + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: @@ -197,9 +180,7 @@ def run_clip_compel( def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_manager.load.load_model_by_key( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) lora_model = lora_info.model assert isinstance(lora_model, LoRAModelRaw) yield (lora_model, lora.weight) @@ -212,11 +193,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_model = context.services.model_manager.load.load_model_by_attr( - model_name=name, - base_model=text_encoder_info.config.base, - model_type=ModelType.TextualInversion, - context=context, + ti_model = context.models.load_by_attrs( + model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion ).model assert isinstance(ti_model, TextualInversionModelRaw) ti_list.append((name, ti_model)) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 01124f62f3c..15e254010b5 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -14,8 +14,7 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_management.models.base import BaseModelType, ModelType -from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +from invokeai.backend.model_manager.config import BaseModelType, ModelType # LS: Consider moving these two classes into model.py @@ -90,10 +89,10 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) + ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_models = context.services.model_manager.store.search_by_attr( + image_encoder_models = context.models.search_by_attrs( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 05293fdfee3..5dd0eb074d5 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -141,7 +141,7 @@ def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: @torch.no_grad() def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) @@ -153,10 +153,7 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: ) if image_tensor is not None: - vae_info = context.services.model_manager.load.load_model_by_key( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) @@ -182,10 +179,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.load.load_model_by_key( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -399,12 +393,7 @@ def prep_control_data( # and if weight is None, populate with default 1.0? controlnet_data = [] for control_info in control_list: - control_model = exit_stack.enter_context( - context.services.model_manager.load.load_model_by_key( - key=control_info.control_model.key, - context=context, - ) - ) + control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key)) # control_models.append(control_model) control_image_field = control_info.image @@ -466,25 +455,17 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.load.load_model_by_key( - key=single_ip_adapter.ip_adapter_model.key, - context=context, - ) + context.models.load(key=single_ip_adapter.ip_adapter_model.key) ) - image_encoder_model_info = context.services.model_manager.load.load_model_by_key( - key=single_ip_adapter.image_encoder_model.key, - context=context, - ) + image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. 
single_ipa_image_fields = single_ip_adapter.image if not isinstance(single_ipa_image_fields, list): single_ipa_image_fields = [single_ipa_image_fields] - single_ipa_images = [ - context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields - ] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -528,10 +509,7 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key( - key=t2i_adapter_field.t2i_adapter_model.key, - context=context, - ) + t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. @@ -676,30 +654,20 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - # get the unet's config so that we can pass the base to dispatch_progress() - unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) + unet_config = context.models.get_config(self.unet.unet.key) def step_callback(state: PipelineIntermediateState) -> None: - self.dispatch_progress(context, source_node_id, state, unet_config.base) + context.util.sd_step_callback(state, unet_config.base) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.services.model_manager.load.load_model_by_key( - **lora.model_dump(exclude={"weight"}), - context=context, - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.load.load_model_by_key( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, @@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.tensors.load(self.latents.latents_name) - vae_info = context.services.model_manager.load.load_model_by_key( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: assert isinstance(vae, torch.nn.Module) @@ -1032,10 +997,7 @@ def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: t def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.load.load_model_by_key( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1239,10 
+1201,7 @@ def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR return tuple((x - x % multiple_of) for x in args) def invoke(self, context: InvocationContext) -> IdealSizeOutput: - unet_config = context.services.model_manager.load.load_model_by_key( - **self.unet.unet.model_dump(), - context=context, - ) + unet_config = context.models.get_config(self.unet.unet.key) aspect = self.width / self.height dimension: float = 512 if unet_config.base == BaseModelType.StableDiffusion2: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 71a71a63c83..6087bc82db1 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -103,7 +103,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.store.exists(key): + if not context.models.exists(key): raise Exception(f"Unknown model {key}") return ModelLoaderOutput( @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_manager.store.exists(lora_key): + if not context.models.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -252,7 +252,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_manager.store.exists(lora_key): + if not context.models.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> VAEOutput: key = self.vae_model.key - if not context.services.model_manager.store.exists(key): + if not context.models.exists(key): raise Exception(f"Unknown vae: {key}!") return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 85e6fb787fa..0df27c00110 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -43,7 +43,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.store.exists(model_key): + if not context.models.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( @@ -112,7 +112,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.store.exists(model_key): + if not context.models.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 45eaf4652fb..f4dd905135a 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from
invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -19,14 +19,14 @@ def load_model_by_key( self, key: str, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's key, load it and return the LoadedModel object. :param key: Key of model config to be fetched. :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting + :param context_data: Invocation context data used for event reporting """ pass @@ -35,14 +35,14 @@ def load_model_by_config( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting + :param context_data: Invocation context data used for event reporting """ pass @@ -53,7 +53,7 @@ def load_model_by_attr( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. @@ -66,7 +66,7 @@ def load_model_by_attr( :param base_model: Base model :param model_type: Type of the model :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. + :param context_data: The invocation context data. 
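A hedged usage sketch of this attribute-based lookup, assuming a concrete ModelLoadServiceBase implementation is in hand; the model name is illustrative, and the enum members are the ones used elsewhere in this patch series.

```python
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType

def run_with_unet(service, context_data=None):
    # Attribute-based lookup; raises UnknownModelException if nothing matches.
    loaded = service.load_model_by_attr(
        model_name="stable-diffusion-v1-5",  # illustrative name
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel=SubModelType.UNet,
        context_data=context_data,  # optional; enables load started/completed events
    )
    # LoadedModel is used as a context manager elsewhere in this series
    # (e.g. `with unet_info as unet:`), keeping the model resident while in scope.
    with loaded as unet:
        print(type(unet).__name__)  # e.g. UNet2DConditionModel
```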
Exceptions: UnknownModelException -- model with these attributes not known NotImplementedException -- a model loader was not provided at initialization time diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index a6ccd5afbc3..29b297c8145 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -3,10 +3,11 @@ from typing import Optional -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -46,6 +47,9 @@ def __init__( ), ) + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + @property def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache used by this loader.""" @@ -60,7 +64,7 @@ def load_model_by_key( self, key: str, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's key, load it and return the LoadedModel object. @@ -70,7 +74,7 @@ def load_model_by_key( :param context: Invocation context used for event reporting """ config = self._store.get_model(key) - return self.load_model_by_config(config, submodel_type, context) + return self.load_model_by_config(config, submodel_type, context_data) def load_model_by_attr( self, @@ -78,7 +82,7 @@ def load_model_by_attr( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. @@ -109,7 +113,7 @@ def load_model_by_config( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. @@ -118,15 +122,15 @@ def load_model_by_config( :param submodel: For main (pipeline models), the submodel to fetch. 
         :param context: Invocation context used for event reporting
         """
-        if context:
+        if context_data:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_config=model_config,
             )
         loaded_model = self._any_loader.load_model(model_config, submodel_type)
-        if context:
+        if context_data:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_config=model_config,
                 loaded=True,
             )
@@ -134,26 +138,28 @@ def load_model_by_config(

     def _emit_load_event(
         self,
-        context: InvocationContext,
+        context_data: InvocationContextData,
         model_config: AnyModelConfig,
         loaded: Optional[bool] = False,
     ) -> None:
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
+        if not self._invoker:
+            return
+        if self._invoker.services.queue.is_canceled(context_data.session_id):
             raise CanceledException()
         if not loaded:
-            context.services.events.emit_model_load_started(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_started(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_config=model_config,
             )
         else:
-            context.services.events.emit_model_load_completed(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_completed(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_config=model_config,
             )

diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index c68dc1140b2..089d09f825c 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from pathlib import Path
 from typing import TYPE_CHECKING, Optional

 from PIL.Image import Image
@@ -12,8 +13,9 @@
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_management.model_manager import LoadedModelInfo
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

@@ -259,45 +261,95 @@

 class ModelsInterface(InvocationContextInterface):
-    def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
+    def exists(self, key: str) -> bool:
         """
         Checks if a model exists.

-        :param model_name: The name of the model to check.
-        :param base_model: The base model of the model to check.
-        :param model_type: The type of the model to check.
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_exists(model_name, base_model, model_type)
+        return self._services.model_manager.store.exists(key)

-    def load(
-        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
-    ) -> LoadedModelInfo:
+    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Loads a model.

-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
-        :param submodel: The submodel of the model to get.
+        :param key: The key of the model.
+        :param submodel_type: The submodel of the model to get.
         :returns: An object representing the loaded model.
         """

         # The model manager emits events as it loads the model. It needs the context data to build
         # the event payloads.

-        return self._services.model_manager.get_model(
-            model_name, base_model, model_type, submodel, context_data=self._context_data
+        return self._services.model_manager.load.load_model_by_key(
+            key=key, submodel_type=submodel_type, context_data=self._context_data
         )

-    def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+    def load_by_attrs(
+        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
+    ) -> LoadedModel:
+        """
+        Loads a model by its attributes.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        """
+        return self._services.model_manager.load.load_model_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel,
+            context_data=self._context_data,
+        )
+
+    def get_config(self, key: str) -> AnyModelConfig:
         """
         Gets a model's info, a dict-like object.

-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
+        :param key: The key of the model.
+        """
+        return self._services.model_manager.store.get_model(key=key)
+
+    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
+        """
+        Gets a model's metadata, if it has any.
+
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_info(model_name, base_model, model_type)
+        return self._services.model_manager.store.get_metadata(key=key)
+
+    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
+        """
+        Searches for models by path.
+
+        :param path: The path to search for.
+        """
+        return self._services.model_manager.store.search_by_path(path)
+
+    def search_by_attrs(
+        self,
+        model_name: Optional[str] = None,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
+        model_format: Optional[ModelFormat] = None,
+    ) -> list[AnyModelConfig]:
+        """
+        Searches for models by attributes.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param model_format: Format of the model
+        """
+
+        return self._services.model_manager.store.search_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_format=model_format,
+        )


 class ConfigInterface(InvocationContextInterface):
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index d83b380d95d..33d00ca3660 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -4,8 +4,8 @@
 from PIL import Image

 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
+from invokeai.backend.model_manager.config import BaseModelType

-from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL

From 6e5e9176c09fd1c4326ac7a3eccd863e1bb87397 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 15 Feb 2024 20:50:47 +1100
Subject: [PATCH 093/100] chore: ruff

---
 invokeai/app/invocations/controlnet_image_processors.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 580ee085627..8542134fff0 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -39,7 +39,6 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
 from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
-from invokeai.backend.model_management.models.base import BaseModelType

 from .baseinvocation import (
     BaseInvocation,

From ed434725820961e60056dc5ecaf4b473b3f6b153 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 15 Feb 2024 20:52:44 +1100
Subject: [PATCH 094/100] chore: lint

---
 invokeai/frontend/web/src/features/nodes/types/error.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts
index c3da136c7a8..82bc0f86e09 100644
--- a/invokeai/frontend/web/src/features/nodes/types/error.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/error.ts
@@ -60,4 +60,4 @@ export class FieldParseError extends Error {
 export class UnableToExtractSchemaNameFromRefError extends FieldParseError {}
 export class UnsupportedArrayItemType extends FieldParseError {}
 export class UnsupportedUnionError extends FieldParseError {}
-export class UnsupportedPrimitiveTypeError extends FieldParseError {}
\ No newline at end of file
+export class UnsupportedPrimitiveTypeError extends FieldParseError {}

From 2bd1ab2f1cde19a364946b4000b96afd1e191718 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 15 Feb 2024 20:53:41 +1100
Subject: [PATCH 095/100] fix(ui): fix type issues

---
 .../nodes/components/sidePanel/viewMode/WorkflowField.tsx | 4 ++--
 .../src/features/nodes/util/schema/parseFieldType.test.ts | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx
b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx index 0e5857933a7..e707dd4f54d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx @@ -16,7 +16,7 @@ type Props = { const WorkflowField = ({ nodeId, fieldName }: Props) => { const label = useFieldLabel(nodeId, fieldName); - const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'input'); + const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); return ( @@ -36,7 +36,7 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts index 2f4ce48a326..d7011ad6f84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -284,13 +284,13 @@ describe('refObjectToSchemaName', async () => { }); describe.concurrent('parseFieldType', async () => { - it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }) => { + it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { expect(parseFieldType(schema)).toEqual(expected); }); - it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }) => { + it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { expect(parseFieldType(schema)).toEqual(expected); }); - it.each(specialCases)('parses special case types ($name)', ({ schema, expected }) => { + it.each(specialCases)('parses special case types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { expect(parseFieldType(schema)).toEqual(expected); }); From 560ae17e215083c08b6fe339ed338c4b03d22586 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 21:16:25 +1100 Subject: [PATCH 096/100] feat(ui): export components type --- .../frontend/web/src/services/api/types.ts | 228 +++++++++--------- 1 file changed, 114 insertions(+), 114 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 55ff808b404..f9a1decf655 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -3,7 +3,7 @@ import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; -type s = components['schemas']; +export type S = components['schemas']; export type ImageCache = EntityState; @@ -23,60 +23,60 @@ export type BatchConfig = export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult']; -export type InputFieldJSONSchemaExtra = s['InputFieldJSONSchemaExtra']; -export type OutputFieldJSONSchemaExtra = s['OutputFieldJSONSchemaExtra']; -export type InvocationJSONSchemaExtra = s['UIConfigBase']; +export type InputFieldJSONSchemaExtra = S['InputFieldJSONSchemaExtra']; +export type OutputFieldJSONSchemaExtra = S['OutputFieldJSONSchemaExtra']; +export type InvocationJSONSchemaExtra = S['UIConfigBase']; // 
App Info -export type AppVersion = s['AppVersion']; -export type AppConfig = s['AppConfig']; -export type AppDependencyVersions = s['AppDependencyVersions']; +export type AppVersion = S['AppVersion']; +export type AppConfig = S['AppConfig']; +export type AppDependencyVersions = S['AppDependencyVersions']; // Images -export type ImageDTO = s['ImageDTO']; -export type BoardDTO = s['BoardDTO']; -export type BoardChanges = s['BoardChanges']; -export type ImageChanges = s['ImageRecordChanges']; -export type ImageCategory = s['ImageCategory']; -export type ResourceOrigin = s['ResourceOrigin']; -export type ImageField = s['ImageField']; -export type OffsetPaginatedResults_BoardDTO_ = s['OffsetPaginatedResults_BoardDTO_']; -export type OffsetPaginatedResults_ImageDTO_ = s['OffsetPaginatedResults_ImageDTO_']; +export type ImageDTO = S['ImageDTO']; +export type BoardDTO = S['BoardDTO']; +export type BoardChanges = S['BoardChanges']; +export type ImageChanges = S['ImageRecordChanges']; +export type ImageCategory = S['ImageCategory']; +export type ResourceOrigin = S['ResourceOrigin']; +export type ImageField = S['ImageField']; +export type OffsetPaginatedResults_BoardDTO_ = S['OffsetPaginatedResults_BoardDTO_']; +export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = s['invokeai__backend__model_management__models__base__ModelType']; -export type SubModelType = s['SubModelType']; -export type BaseModelType = s['invokeai__backend__model_management__models__base__BaseModelType']; -export type MainModelField = s['MainModelField']; -export type VAEModelField = s['VAEModelField']; -export type LoRAModelField = s['LoRAModelField']; -export type LoRAModelFormat = s['LoRAModelFormat']; -export type ControlNetModelField = s['ControlNetModelField']; -export type IPAdapterModelField = s['IPAdapterModelField']; -export type T2IAdapterModelField = s['T2IAdapterModelField']; -export type ModelsList = s['invokeai__app__api__routers__models__ModelsList']; -export type ControlField = s['ControlField']; -export type IPAdapterField = s['IPAdapterField']; +export type ModelType = S['ModelType']; +export type SubModelType = S['SubModelType']; +export type BaseModelType = S['BaseModelType']; +export type MainModelField = S['MainModelField']; +export type VAEModelField = S['VAEModelField']; +export type LoRAModelField = S['LoRAModelField']; +export type LoRAModelFormat = S['LoRAModelFormat']; +export type ControlNetModelField = S['ControlNetModelField']; +export type IPAdapterModelField = S['IPAdapterModelField']; +export type T2IAdapterModelField = S['T2IAdapterModelField']; +export type ModelsList = S['invokeai__app__api__routers__models__ModelsList']; +export type ControlField = S['ControlField']; +export type IPAdapterField = S['IPAdapterField']; // Model Configs -export type LoRAModelConfig = s['LoRAModelConfig']; -export type VaeModelConfig = s['VaeModelConfig']; -export type ControlNetModelCheckpointConfig = s['ControlNetModelCheckpointConfig']; -export type ControlNetModelDiffusersConfig = s['ControlNetModelDiffusersConfig']; +export type LoRAModelConfig = S['LoRAModelConfig']; +export type VaeModelConfig = S['VaeModelConfig']; +export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig']; +export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig']; export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig; -export type IPAdapterModelInvokeAIConfig = 
s['IPAdapterModelInvokeAIConfig']; +export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig']; export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; -export type T2IAdapterModelDiffusersConfig = s['T2IAdapterModelDiffusersConfig']; +export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig']; export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig; -export type TextualInversionModelConfig = s['TextualInversionModelConfig']; +export type TextualInversionModelConfig = S['TextualInversionModelConfig']; export type DiffusersModelConfig = - | s['StableDiffusion1ModelDiffusersConfig'] - | s['StableDiffusion2ModelDiffusersConfig'] - | s['StableDiffusionXLModelDiffusersConfig']; + | S['StableDiffusion1ModelDiffusersConfig'] + | S['StableDiffusion2ModelDiffusersConfig'] + | S['StableDiffusionXLModelDiffusersConfig']; export type CheckpointModelConfig = - | s['StableDiffusion1ModelCheckpointConfig'] - | s['StableDiffusion2ModelCheckpointConfig'] - | s['StableDiffusionXLModelCheckpointConfig']; + | S['StableDiffusion1ModelCheckpointConfig'] + | S['StableDiffusion2ModelCheckpointConfig'] + | S['StableDiffusionXLModelCheckpointConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = | LoRAModelConfig @@ -87,87 +87,87 @@ export type AnyModelConfig = | TextualInversionModelConfig | MainModelConfig; -export type MergeModelConfig = s['Body_merge_models']; -export type ImportModelConfig = s['Body_import_model']; +export type MergeModelConfig = S['Body_merge_models']; +export type ImportModelConfig = S['Body_import_model']; // Graphs -export type Graph = s['Graph']; +export type Graph = S['Graph']; export type NonNullableGraph = O.Required; -export type Edge = s['Edge']; -export type GraphExecutionState = s['GraphExecutionState']; -export type Batch = s['Batch']; -export type SessionQueueItemDTO = s['SessionQueueItemDTO']; -export type SessionQueueItem = s['SessionQueueItem']; -export type WorkflowRecordOrderBy = s['WorkflowRecordOrderBy']; -export type SQLiteDirection = s['SQLiteDirection']; -export type WorkflowDTO = s['WorkflowRecordDTO']; -export type WorkflowRecordListItemDTO = s['WorkflowRecordListItemDTO']; +export type Edge = S['Edge']; +export type GraphExecutionState = S['GraphExecutionState']; +export type Batch = S['Batch']; +export type SessionQueueItemDTO = S['SessionQueueItemDTO']; +export type SessionQueueItem = S['SessionQueueItem']; +export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy']; +export type SQLiteDirection = S['SQLiteDirection']; +export type WorkflowDTO = S['WorkflowRecordDTO']; +export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO']; // General nodes -export type CollectInvocation = s['CollectInvocation']; -export type IterateInvocation = s['IterateInvocation']; -export type RangeInvocation = s['RangeInvocation']; -export type RandomRangeInvocation = s['RandomRangeInvocation']; -export type RangeOfSizeInvocation = s['RangeOfSizeInvocation']; -export type ImageResizeInvocation = s['ImageResizeInvocation']; -export type ImageBlurInvocation = s['ImageBlurInvocation']; -export type ImageScaleInvocation = s['ImageScaleInvocation']; -export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation']; -export type InfillTileInvocation = s['InfillTileInvocation']; -export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation']; -export type MaskEdgeInvocation = s['MaskEdgeInvocation']; -export type RandomIntInvocation = 
s['RandomIntInvocation']; -export type CompelInvocation = s['CompelInvocation']; -export type DynamicPromptInvocation = s['DynamicPromptInvocation']; -export type NoiseInvocation = s['NoiseInvocation']; -export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation']; -export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation']; -export type ImageToLatentsInvocation = s['ImageToLatentsInvocation']; -export type LatentsToImageInvocation = s['LatentsToImageInvocation']; -export type ImageCollectionInvocation = s['ImageCollectionInvocation']; -export type MainModelLoaderInvocation = s['MainModelLoaderInvocation']; -export type LoraLoaderInvocation = s['LoraLoaderInvocation']; -export type ESRGANInvocation = s['ESRGANInvocation']; -export type DivideInvocation = s['DivideInvocation']; -export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; -export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; -export type SeamlessModeInvocation = s['SeamlessModeInvocation']; -export type MetadataInvocation = s['MetadataInvocation']; -export type CoreMetadataInvocation = s['CoreMetadataInvocation']; -export type MetadataItemInvocation = s['MetadataItemInvocation']; -export type MergeMetadataInvocation = s['MergeMetadataInvocation']; -export type IPAdapterMetadataField = s['IPAdapterMetadataField']; -export type T2IAdapterField = s['T2IAdapterField']; -export type LoRAMetadataField = s['LoRAMetadataField']; +export type CollectInvocation = S['CollectInvocation']; +export type IterateInvocation = S['IterateInvocation']; +export type RangeInvocation = S['RangeInvocation']; +export type RandomRangeInvocation = S['RandomRangeInvocation']; +export type RangeOfSizeInvocation = S['RangeOfSizeInvocation']; +export type ImageResizeInvocation = S['ImageResizeInvocation']; +export type ImageBlurInvocation = S['ImageBlurInvocation']; +export type ImageScaleInvocation = S['ImageScaleInvocation']; +export type InfillPatchMatchInvocation = S['InfillPatchMatchInvocation']; +export type InfillTileInvocation = S['InfillTileInvocation']; +export type CreateDenoiseMaskInvocation = S['CreateDenoiseMaskInvocation']; +export type MaskEdgeInvocation = S['MaskEdgeInvocation']; +export type RandomIntInvocation = S['RandomIntInvocation']; +export type CompelInvocation = S['CompelInvocation']; +export type DynamicPromptInvocation = S['DynamicPromptInvocation']; +export type NoiseInvocation = S['NoiseInvocation']; +export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation']; +export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation']; +export type ImageToLatentsInvocation = S['ImageToLatentsInvocation']; +export type LatentsToImageInvocation = S['LatentsToImageInvocation']; +export type ImageCollectionInvocation = S['ImageCollectionInvocation']; +export type MainModelLoaderInvocation = S['MainModelLoaderInvocation']; +export type LoraLoaderInvocation = S['LoraLoaderInvocation']; +export type ESRGANInvocation = S['ESRGANInvocation']; +export type DivideInvocation = S['DivideInvocation']; +export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation']; +export type ImageWatermarkInvocation = S['ImageWatermarkInvocation']; +export type SeamlessModeInvocation = S['SeamlessModeInvocation']; +export type MetadataInvocation = S['MetadataInvocation']; +export type CoreMetadataInvocation = S['CoreMetadataInvocation']; +export type MetadataItemInvocation = S['MetadataItemInvocation']; +export type MergeMetadataInvocation = S['MergeMetadataInvocation']; +export type 
IPAdapterMetadataField = S['IPAdapterMetadataField']; +export type T2IAdapterField = S['T2IAdapterField']; +export type LoRAMetadataField = S['LoRAMetadataField']; // ControlNet Nodes -export type ControlNetInvocation = s['ControlNetInvocation']; -export type T2IAdapterInvocation = s['T2IAdapterInvocation']; -export type IPAdapterInvocation = s['IPAdapterInvocation']; -export type CannyImageProcessorInvocation = s['CannyImageProcessorInvocation']; -export type ColorMapImageProcessorInvocation = s['ColorMapImageProcessorInvocation']; -export type ContentShuffleImageProcessorInvocation = s['ContentShuffleImageProcessorInvocation']; -export type DepthAnythingImageProcessorInvocation = s['DepthAnythingImageProcessorInvocation']; -export type HedImageProcessorInvocation = s['HedImageProcessorInvocation']; -export type LineartAnimeImageProcessorInvocation = s['LineartAnimeImageProcessorInvocation']; -export type LineartImageProcessorInvocation = s['LineartImageProcessorInvocation']; -export type MediapipeFaceProcessorInvocation = s['MediapipeFaceProcessorInvocation']; -export type MidasDepthImageProcessorInvocation = s['MidasDepthImageProcessorInvocation']; -export type MlsdImageProcessorInvocation = s['MlsdImageProcessorInvocation']; -export type NormalbaeImageProcessorInvocation = s['NormalbaeImageProcessorInvocation']; -export type DWOpenposeImageProcessorInvocation = s['DWOpenposeImageProcessorInvocation']; -export type PidiImageProcessorInvocation = s['PidiImageProcessorInvocation']; -export type ZoeDepthImageProcessorInvocation = s['ZoeDepthImageProcessorInvocation']; +export type ControlNetInvocation = S['ControlNetInvocation']; +export type T2IAdapterInvocation = S['T2IAdapterInvocation']; +export type IPAdapterInvocation = S['IPAdapterInvocation']; +export type CannyImageProcessorInvocation = S['CannyImageProcessorInvocation']; +export type ColorMapImageProcessorInvocation = S['ColorMapImageProcessorInvocation']; +export type ContentShuffleImageProcessorInvocation = S['ContentShuffleImageProcessorInvocation']; +export type DepthAnythingImageProcessorInvocation = S['DepthAnythingImageProcessorInvocation']; +export type HedImageProcessorInvocation = S['HedImageProcessorInvocation']; +export type LineartAnimeImageProcessorInvocation = S['LineartAnimeImageProcessorInvocation']; +export type LineartImageProcessorInvocation = S['LineartImageProcessorInvocation']; +export type MediapipeFaceProcessorInvocation = S['MediapipeFaceProcessorInvocation']; +export type MidasDepthImageProcessorInvocation = S['MidasDepthImageProcessorInvocation']; +export type MlsdImageProcessorInvocation = S['MlsdImageProcessorInvocation']; +export type NormalbaeImageProcessorInvocation = S['NormalbaeImageProcessorInvocation']; +export type DWOpenposeImageProcessorInvocation = S['DWOpenposeImageProcessorInvocation']; +export type PidiImageProcessorInvocation = S['PidiImageProcessorInvocation']; +export type ZoeDepthImageProcessorInvocation = S['ZoeDepthImageProcessorInvocation']; // Node Outputs -export type ImageOutput = s['ImageOutput']; -export type StringOutput = s['StringOutput']; -export type FloatOutput = s['FloatOutput']; -export type IntegerOutput = s['IntegerOutput']; -export type IterateInvocationOutput = s['IterateInvocationOutput']; -export type CollectInvocationOutput = s['CollectInvocationOutput']; -export type LatentsOutput = s['LatentsOutput']; -export type GraphInvocationOutput = s['GraphInvocationOutput']; +export type ImageOutput = S['ImageOutput']; +export type StringOutput = S['StringOutput']; 
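The practical effect of exporting `S` (patch 096) is that feature code can reference any generated schema type directly, rather than waiting for a hand-written alias to be added to this file. A short sketch of a hypothetical consumer follows; the `isImageField` guard and its `image_name` check are illustrative assumptions, not part of the patch:

```ts
// Hypothetical consumer module using the exported schema lookup type.
import type { S } from 'services/api/types';

type ImageField = S['ImageField'];

// Sketch of a runtime guard narrowing unknown data to a schema type.
const isImageField = (candidate: unknown): candidate is ImageField =>
  typeof candidate === 'object' && candidate !== null && 'image_name' in candidate;

export { isImageField };
```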
+export type FloatOutput = S['FloatOutput']; +export type IntegerOutput = S['IntegerOutput']; +export type IterateInvocationOutput = S['IterateInvocationOutput']; +export type CollectInvocationOutput = S['CollectInvocationOutput']; +export type LatentsOutput = S['LatentsOutput']; +export type GraphInvocationOutput = S['GraphInvocationOutput']; // Post-image upload actions, controls workflows when images are uploaded From 79fb691b4d450dff65946c5c98db2a5726eeb68b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:15:21 +1100 Subject: [PATCH 097/100] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 2024 +++++++---------- 1 file changed, 847 insertions(+), 1177 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 1599b310c9a..3393e74d486 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -19,80 +19,14 @@ export type paths = { */ post: operations["parse_dynamicprompts"]; }; - "/api/v1/models/": { - /** - * List Models - * @description Gets a list of models - */ - get: operations["list_models"]; - }; - "/api/v1/models/{base_model}/{model_type}/{model_name}": { - /** - * Delete Model - * @description Delete Model - */ - delete: operations["del_model"]; - /** - * Update Model - * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. - */ - patch: operations["update_model"]; - }; - "/api/v1/models/import": { - /** - * Import Model - * @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically - */ - post: operations["import_model"]; - }; - "/api/v1/models/add": { - /** - * Add Model - * @description Add a model using the configuration information appropriate for its type. Only local models can be added by path - */ - post: operations["add_model"]; - }; - "/api/v1/models/convert/{base_model}/{model_type}/{model_name}": { - /** - * Convert Model - * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. - */ - put: operations["convert_model"]; - }; - "/api/v1/models/search": { - /** Search For Models */ - get: operations["search_for_models"]; - }; - "/api/v1/models/ckpt_confs": { - /** - * List Ckpt Configs - * @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT. - */ - get: operations["list_ckpt_configs"]; - }; - "/api/v1/models/sync": { - /** - * Sync To Config - * @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize - * in-memory data structures with disk data structures. - */ - post: operations["sync_to_config"]; - }; - "/api/v1/models/merge/{base_model}": { - /** - * Merge Models - * @description Convert a checkpoint model into a diffusers model - */ - put: operations["merge_models"]; - }; - "/api/v1/model/record/": { + "/api/v2/models/": { /** * List Model Records * @description Get a list of models. 
*/ get: operations["list_model_records"]; }; - "/api/v1/model/record/i/{key}": { + "/api/v2/models/i/{key}": { /** * Get Model Record * @description Get a model record @@ -112,50 +46,76 @@ export type paths = { */ patch: operations["update_model_record"]; }; - "/api/v1/model/record/meta": { + "/api/v2/models/summary": { /** * List Model Summary * @description Gets a page of model summary data. */ get: operations["list_model_summary"]; }; - "/api/v1/model/record/meta/i/{key}": { + "/api/v2/models/meta/i/{key}": { /** * Get Model Metadata * @description Get a model metadata object. */ get: operations["get_model_metadata"]; }; - "/api/v1/model/record/tags": { + "/api/v2/models/tags": { /** * List Tags * @description Get a unique set of all the model tags. */ get: operations["list_tags"]; }; - "/api/v1/model/record/tags/search": { + "/api/v2/models/tags/search": { /** * Search By Metadata Tags * @description Get a list of models. */ get: operations["search_by_metadata_tags"]; }; - "/api/v1/model/record/i/": { + "/api/v2/models/i/": { /** * Add Model Record * @description Add a model using the configuration information appropriate for its type. */ post: operations["add_model_record"]; }; - "/api/v1/model/record/import": { + "/api/v2/models/heuristic_import": { /** - * List Model Install Jobs - * @description Return list of model install jobs. + * Heuristic Import + * @description Install a model using a string identifier. + * + * `source` can be any of the following. + * + * 1. A path on the local filesystem ('C:\users\fred\model.safetensors') + * 2. A Url pointing to a single downloadable model file + * 3. A HuggingFace repo_id with any of the following formats: + * - model/name + * - model/name:fp16:vae + * - model/name::vae -- use default precision + * - model/name:fp16:path/to/model.safetensors + * - model/name::path/to/model.safetensors + * + * `config` is an optional dict containing model configuration values that will override + * the ones that are probed automatically. + * + * `access_token` is an optional access token for use with Urls that require + * authentication. + * + * Models will be downloaded, probed, configured and installed in a + * series of background threads. The return object has `status` attribute + * that can be used to monitor progress. + * + * See the documentation for `import_model_record` for more information on + * interpreting the job information returned by this route. */ - get: operations["list_model_install_jobs"]; + post: operations["heuristic_import_model"]; + }; + "/api/v2/models/install": { /** * Import Model - * @description Add a model using its local path, repo_id, or remote URL. + * @description Install a model using its local path, repo_id, or remote URL. * * Models will be downloaded, probed, configured and installed in a * series of background threads. The return object has `status` attribute @@ -166,32 +126,38 @@ export type paths = { * appropriate value: * * * To install a local path using LocalModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "local", * "path": "/path/to/model", * "inplace": false - * }` - * The "inplace" flag, if true, will register the model in place in its - * current filesystem location. Otherwise, the model will be copied - * into the InvokeAI models directory. + * } + * ``` + * The "inplace" flag, if true, will register the model in place in its + * current filesystem location. Otherwise, the model will be copied + * into the InvokeAI models directory. 
     *
     * * To install a HuggingFace repo_id using HFModelSource, pass a source of form:
-    * `{
+    * ```
+    * {
     * "type": "hf",
     * "repo_id": "stabilityai/stable-diffusion-2.0",
     * "variant": "fp16",
     * "subfolder": "vae",
     * "access_token": "f5820a918aaf01"
-    * }`
-    * The `variant`, `subfolder` and `access_token` fields are optional.
+    * }
+    * ```
+    * The `variant`, `subfolder` and `access_token` fields are optional.
     *
     * * To install a remote model using an arbitrary URL, pass:
-    * `{
+    * ```
+    * {
     * "type": "url",
     * "url": "http://www.civitai.com/models/123456",
     * "access_token": "f5820a918aaf01"
-    * }`
-    * The `access_token` field is optonal
+    * }
+    * ```
+    * The `access_token` field is optional
     *
     * The model's configuration record will be probed and filled in
     * automatically. To override the default guesses, pass "metadata"
     * with a Dict containing the desired configuration values.
     *
     * Installation occurs in the background. Either use list_model_install_jobs()
     * to poll for completion, or listen on the event bus for the following events:
     *
-    * "model_install_running"
-    * "model_install_completed"
-    * "model_install_error"
+    * * "model_install_running"
+    * * "model_install_completed"
+    * * "model_install_error"
     *
     * On successful completion, the event's payload will contain the field "key"
     * containing the installed ID of the model. On an error, the event's payload
     * will contain the fields "error_type" and "error" describing the nature of the
     * error and its traceback, respectively.
     */
-    post: operations["import_model_record"];
+    post: operations["import_model"];
+  };
+  "/api/v2/models/import": {
+    /**
+     * List Model Install Jobs
+     * @description Return the list of model install jobs.
+     *
+     * Install jobs have a numeric `id`, a `status`, and other fields that provide information on
+     * the nature of the job and its progress. The `status` is one of:
+     *
+     * * "waiting" -- Job is waiting in the queue to run
+     * * "downloading" -- Model file(s) are downloading
+     * * "running" -- Model has downloaded and the model probing and registration process is running
+     * * "completed" -- Installation completed successfully
+     * * "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
+     * * "cancelled" -- Job was cancelled before completion.
+     *
+     * Once completed, information about the model such as its size, base
+     * model, type, and metadata can be retrieved from the `config_out`
+     * field. For multi-file models such as diffusers, information on individual files
+     * can be retrieved from `download_parts`.
+     *
+     * See the example and schema below for more information.
+     */
+    get: operations["list_model_install_jobs"];
     /**
      * Prune Model Install Jobs
      * @description Prune all completed and errored jobs from the install job list.
      */
     patch: operations["prune_model_install_jobs"];
   };
-  "/api/v1/model/record/import/{id}": {
+  "/api/v2/models/import/{id}": {
     /**
      * Get Model Install Job
-     * @description Return model install job corresponding to the given source.
+     * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
+     * for information on the format of the return value.
      */
     get: operations["get_model_install_job"];
     /**
      * Cancel Model Install Job
      * @description Cancel the model install job(s) corresponding to the given job ID.
      */
     delete: operations["cancel_model_install_job"];
   };
-  "/api/v1/model/record/sync": {
+  "/api/v2/models/sync": {
     /**
      * Sync Models To Config
      * @description Traverse the models and autoimport directories.
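Taken together, the install route and the job-listing route above form a small asynchronous workflow: a client enqueues an install, then tracks the job until it settles, either by polling or by listening for the `model_install_*` events on the bus. A minimal sketch of that flow from a TypeScript client follows; the base URL, the use of plain `fetch` in place of the app's generated client, and the minimal `InstallJob` shape are assumptions inferred from the docstrings above, not part of the patch:

```ts
// Sketch: enqueue a HuggingFace install via POST /api/v2/models/install,
// then poll GET /api/v2/models/import until the job settles.
const BASE_URL = 'http://127.0.0.1:9090'; // assumption: local dev server

type InstallJob = { id: number; status: string }; // minimal assumed shape

async function installAndWait(repoId: string): Promise<string> {
  const res = await fetch(`${BASE_URL}/api/v2/models/install`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ source: { type: 'hf', repo_id: repoId, variant: 'fp16' } }),
  });
  const job: InstallJob = await res.json();

  // "waiting" | "downloading" | "running" eventually become
  // "completed" | "error" | "cancelled", per the docstring above.
  let status = job.status;
  while (['waiting', 'downloading', 'running'].includes(status)) {
    await new Promise((resolve) => setTimeout(resolve, 1000));
    const jobs: InstallJob[] = await (await fetch(`${BASE_URL}/api/v2/models/import`)).json();
    status = jobs.find((j) => j.id === job.id)?.status ?? 'error';
  }
  return status;
}
```

In the app itself, the socket events named in the docstring (`model_install_running`, `model_install_completed`, `model_install_error`) make the polling loop unnecessary.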
@@ -238,17 +229,29 @@
      */
     patch: operations["sync_models_to_config"];
   };
-  "/api/v1/model/record/merge": {
+  "/api/v2/models/convert/{key}": {
+    /**
+     * Convert Model
+     * @description Permanently convert a model into diffusers format, replacing the safetensors version.
+     * Note that during the conversion process the key and model hash will change.
+     * The return value is the model configuration for the converted model.
+     */
+    put: operations["convert_model"];
+  };
+  "/api/v2/models/merge": {
     /**
      * Merge
-     * @description Merge diffusers models.
-     *
-     * keys: List of 2-3 model keys to merge together. All models must use the same base type.
-     * merged_model_name: Name for the merged model [Concat model names]
-     * alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
-     * force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
-     * interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
-     * merge_dest_directory: Specify a directory to store the merged model in [models directory]
+     * @description Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
+     * ```
+     * Argument              Description [default]
+     * --------              ----------------------
+     * keys                  List of 2-3 model keys to merge together. All models must use the same base type.
+     * merged_model_name     Name for the merged model [Concat model names]
+     * alpha                 Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
+     * force                 If true, force the merge even if the models were generated by different versions of the diffusers library [False]
+     * interp                Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
+     * merge_dest_directory  Specify a directory to store the merged model in [models directory]
+     * ```
      */
     put: operations["merge"];
   };
@@ -815,6 +818,12 @@ export type components = {
      */
     type?: "basemetadata";
   };
+  /**
+   * BaseModelType
+   * @description Base model type.
+   * @enum {string}
+   */
+  BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner";
   /** Batch */
   Batch: {
     /**
@@ -1163,19 +1172,6 @@ export type components = {
   };
   /** Body_import_model */
   Body_import_model: {
-    /**
-     * Location
-     * @description A model path, repo_id or URL to import
-     */
-    location: string;
-    /**
-     * Prediction Type
-     * @description Prediction type for SDv2 checkpoints and rare SDv1 checkpoints
-     */
-    prediction_type?: ("v_prediction" | "epsilon" | "sample") | null;
-  };
-  /** Body_import_model_record */
-  Body_import_model_record: {
     /** Source */
     source: components["schemas"]["LocalModelSource"] | components["schemas"]["HFModelSource"] | components["schemas"]["CivitaiModelSource"] | components["schemas"]["URLModelSource"];
     /**
@@ -1216,11 +1212,6 @@ export type components = {
      */
     merge_dest_directory?: string | null;
   };
-  /** Body_merge_models */
-  Body_merge_models: {
-    /** @description Model configuration */
-    body: components["schemas"]["MergeModelsBody"];
-  };
   /** Body_parse_dynamicprompts */
   Body_parse_dynamicprompts: {
     /**
@@ -1412,11 +1403,18 @@ export type components = {
      * @description Model config for ClipVision.
*/ CLIPVisionDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default clip_vision @@ -1444,45 +1442,29 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; - }; - /** CLIPVisionModelDiffusersConfig */ - CLIPVisionModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default clip_vision - * @constant - */ - model_type: "clip_vision"; - /** Path */ - path: string; - /** Description */ - description?: string | null; /** - * Model Format - * @constant + * Last Modified + * @description timestamp for modification time */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; + last_modified?: number | null; }; /** CLIPVisionModelField */ CLIPVisionModelField: { /** - * Model Name - * @description Name of the CLIP Vision image encoder model + * Key + * @description Key to the CLIP Vision image encoder model */ - model_name: string; - /** @description Base model (usually 'Any') */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * CV2 Infill @@ -2538,11 +2520,18 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default controlnet @@ -2571,13 +2560,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** * Config * @description path to the checkpoint model config file @@ -2589,11 +2586,18 @@ export type components = { * @description Model config for ControlNet models (diffusers version). 
*/ ControlNetDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default controlnet @@ -2622,13 +2626,23 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; /** * ControlNet @@ -2695,64 +2709,16 @@ export type components = { */ type: "controlnet"; }; - /** ControlNetModelCheckpointConfig */ - ControlNetModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default controlnet - * @constant - */ - model_type: "controlnet"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Config */ - config: string; - }; - /** ControlNetModelDiffusersConfig */ - ControlNetModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default controlnet - * @constant - */ - model_type: "controlnet"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - }; /** * ControlNetModelField * @description ControlNet model field */ ControlNetModelField: { /** - * Model Name - * @description Name of the ControlNet model + * Key + * @description Model config record key for the ControlNet model */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * ControlOutput @@ -4246,7 +4212,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ImageCropInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | 
components["schemas"]["ConditioningInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | 
components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IdealSizeInvocation"]; + [key: string]: components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["T2IAdapterInvocation"] | 
components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MergeMetadataInvocation"] | 
components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"]; }; /** * Edges @@ -4283,7 +4249,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | 
components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ClipSkipInvocationOutput"]; + [key: string]: components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["CLIPOutput"]; }; /** * Errors @@ -4477,11 +4443,18 @@ export type components = { * @description Model config for IP Adapter format models.
*/ IPAdapterConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default ip_adapter @@ -4509,13 +4482,23 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** Image Encoder Model Id */ + image_encoder_model_id: string; }; /** IPAdapterField */ IPAdapterField: { @@ -4632,34 +4615,10 @@ export type components = { /** IPAdapterModelField */ IPAdapterModelField: { /** - * Model Name - * @description Name of the IP-Adapter model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - }; - /** IPAdapterModelInvokeAIConfig */ - IPAdapterModelInvokeAIConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default ip_adapter - * @constant - */ - model_type: "ip_adapter"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant + * Key + * @description Key to the IP-Adapter model */ - model_format: "invokeai"; - error?: components["schemas"]["ModelError"] | null; + key: string; }; /** IPAdapterOutput */ IPAdapterOutput: { @@ -6562,11 +6521,18 @@ export type components = { * @description Model config for LoRA/Lycoris models. 
*/ LoRAConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default lora * @constant */ @@ -6594,13 +6560,21 @@ * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** * LoRAMetadataField @@ -6615,42 +6589,17 @@ */ weight: number; }; - /** LoRAModelConfig */ - LoRAModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default lora - * @constant - */ - model_type: "lora"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - model_format: components["schemas"]["LoRAModelFormat"]; - error?: components["schemas"]["ModelError"] | null; - }; /** * LoRAModelField * @description LoRA model field */ LoRAModelField: { /** - * Model Name - * @description Name of the LoRA model + * Key + * @description LoRA model key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; - /** - * LoRAModelFormat - * @enum {string} - */ - LoRAModelFormat: "lycoris" | "diffusers"; /** * LocalModelSource * @description A local file or directory path. @@ -6678,16 +6627,12 @@ /** LoraInfo */ LoraInfo: { /** - * Model Name - * @description Info to load submodel + * Key + * @description Key of model as returned by ModelRecordServiceBase.get_model() */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Info to load submodel */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; /** @description Info to load submodel */ - submodel?: components["schemas"]["SubModelType"] | null; + submodel_type?: components["schemas"]["SubModelType"] | null; /** * Weight * @description Weight to use when applying this LoRA to the model */ weight: number; }; @@ -6771,11 +6716,18 @@ * @description Model config for main checkpoint models. 
*/ MainCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default main @@ -6804,17 +6756,32 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; + variant?: components["schemas"]["ModelVariantType"]; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; + /** + * Upcast Attention + * @default false + */ + upcast_attention?: boolean; /** * Ztsnr Training * @default false @@ -6831,11 +6798,18 @@ export type components = { * @description Model config for main diffusers models. */ MainDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default main @@ -6864,29 +6838,39 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; + variant?: components["schemas"]["ModelVariantType"]; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** - * Ztsnr Training + * Upcast Attention * @default false */ - ztsnr_training?: boolean; - /** @default epsilon */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + upcast_attention?: boolean; /** - * Upcast Attention + * Ztsnr Training * @default false */ - upcast_attention?: boolean; + ztsnr_training?: boolean; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; /** * MainModelField @@ -6894,14 +6878,10 @@ export type components = { */ MainModelField: { /** - * Model Name - * @description Name of the model + * Key + * @description Model key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** 
@description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; }; /** * Main Model @@ -7153,38 +7133,6 @@ export type components = { */ type: "merge_metadata"; }; - /** MergeModelsBody */ - MergeModelsBody: { - /** - * Model Names - * @description model name - */ - model_names: string[]; - /** - * Merged Model Name - * @description Name of destination model - */ - merged_model_name: string | null; - /** - * Alpha - * @description Alpha weighting strength to apply to 2d and 3d models - * @default 0.5 - */ - alpha?: number | null; - /** @description Interpolation method */ - interp: components["schemas"]["MergeInterpolationMethod"] | null; - /** - * Force - * @description Force merging of models created with different versions of diffusers - * @default false - */ - force?: boolean | null; - /** - * Merge Dest Directory - * @description Save the merged model to the designated directory (with 'merged_model_name' appended) - */ - merge_dest_directory?: string | null; - }; /** * Merge Tiles to Image * @description Merge multiple tile images into a single image. @@ -7459,11 +7407,6 @@ export type components = { */ type: "mlsd_image_processor"; }; - /** - * ModelError - * @constant - */ - ModelError: "not_found"; /** * ModelFormat * @description Storage format of model. @@ -7473,16 +7416,12 @@ export type components = { /** ModelInfo */ ModelInfo: { /** - * Model Name - * @description Info to load submodel + * Key + * @description Key of model as returned by ModelRecordServiceBase.get_model() */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Info to load submodel */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; /** @description Info to load submodel */ - submodel?: components["schemas"]["SubModelType"] | null; + submodel_type?: components["schemas"]["SubModelType"] | null; }; /** * ModelInstallJob @@ -7508,7 +7447,7 @@ export type components = { * Config Out * @description After successful installation, this will hold the configuration object. 
*/ - config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"] | null; + config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"] | null; /** * Inplace * @description Leave model in its current location; otherwise install under models directory @@ -7587,7 +7526,7 @@ export type components = { * @description Various hugging face variants on the diffusers format. * @enum {string} */ - ModelRepoVariant: "default" | "fp16" | "fp32" | "onnx" | "openvino" | "flax"; + ModelRepoVariant: "" | "fp16" | "fp32" | "onnx" | "openvino" | "flax"; /** * ModelSummary * @description A short summary of models for UI listing purposes. @@ -7599,9 +7538,9 @@ export type components = { */ key: string; /** @description model type */ - type: components["schemas"]["invokeai__backend__model_manager__config__ModelType"]; + type: components["schemas"]["ModelType"]; /** @description base model */ - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + base: components["schemas"]["BaseModelType"]; /** @description model format */ format: components["schemas"]["ModelFormat"]; /** @@ -7620,6 +7559,26 @@ export type components = { */ tags: string[]; }; + /** + * ModelType + * @description Model type. + * @enum {string} + */ + ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; + /** + * ModelVariantType + * @description Variant type. + * @enum {string} + */ + ModelVariantType: "normal" | "inpaint" | "depth"; + /** + * ModelsList + * @description Return list of configs. 
+ */ + ModelsList: { + /** Models */ + models: ((components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"])[]; + }; /** * Multiply Integers * @description Multiplies two numbers @@ -7808,9 +7767,15 @@ export type components = { * @description Model config for ONNX format models based on sd-1. */ ONNXSD1Config: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; /** * Base @@ -7845,38 +7810,52 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; - /** - * Ztsnr Training - * @default false - */ - ztsnr_training?: boolean; + variant?: components["schemas"]["ModelVariantType"]; /** @default epsilon */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default false */ upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false + */ + ztsnr_training?: boolean; }; /** * ONNXSD2Config * @description Model config for ONNX format models based on sd-2. 
*/ ONNXSD2Config: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; /** * Base @@ -7911,78 +7890,117 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; - /** - * Ztsnr Training - * @default false - */ - ztsnr_training?: boolean; + variant?: components["schemas"]["ModelVariantType"]; /** @default v_prediction */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default true */ upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false + */ + ztsnr_training?: boolean; }; - /** ONNXStableDiffusion1ModelConfig */ - ONNXStableDiffusion1ModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + /** + * ONNXSDXLConfig + * @description Model config for ONNX format models based on sdxl. + */ + ONNXSDXLConfig: { /** - * Model Type - * @default onnx - * @constant + * Path + * @description filesystem path to the model file or directory */ - model_type: "onnx"; - /** Path */ path: string; - /** Description */ - description?: string | null; /** - * Model Format + * Name + * @description model name + */ + name: string; + /** + * Base + * @default sdxl * @constant */ - model_format: "onnx"; - error?: components["schemas"]["ModelError"] | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** ONNXStableDiffusion2ModelConfig */ - ONNXStableDiffusion2ModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + base?: "sdxl"; /** - * Model Type + * Type * @default onnx * @constant */ - model_type: "onnx"; - /** Path */ - path: string; - /** Description */ + type?: "onnx"; + /** + * Format + * @enum {string} + */ + format: "onnx" | "olive"; + /** + * Key + * @description unique key for model + * @default + */ + key?: string; + /** + * Original Hash + * @description original fasthash of model contents + */ + original_hash?: string | null; + /** + * Current Hash + * @description current fasthash of model contents + */ + current_hash?: string | null; + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** - * Model Format - * @constant + * Source + * @description model original source (path, URL or repo_id) + */ + source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** Vae */ + vae?: string | null; + /** @default normal */ + variant?: components["schemas"]["ModelVariantType"]; + /** 
@default v_prediction */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; + /** + * Upcast Attention + * @default false + */ + upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false */ - model_format: "onnx"; - error?: components["schemas"]["ModelError"] | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - prediction_type: components["schemas"]["invokeai__backend__model_management__models__base__SchedulerPredictionType"]; - /** Upcast Attention */ - upcast_attention: boolean; + ztsnr_training?: boolean; }; /** OffsetPaginatedResults[BoardDTO] */ OffsetPaginatedResults_BoardDTO_: { @@ -9119,6 +9137,12 @@ export type components = { */ type: "scheduler_output"; }; + /** + * SchedulerPredictionType + * @description Scheduler prediction type. + * @enum {string} + */ + SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; /** * Seamless * @description Applies the seamless transformation to the Model UNet and VAE. @@ -9468,162 +9492,6 @@ export type components = { */ type: "show_image"; }; - /** StableDiffusion1ModelCheckpointConfig */ - StableDiffusion1ModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion1ModelDiffusersConfig */ - StableDiffusion1ModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion2ModelCheckpointConfig */ - StableDiffusion2ModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion2ModelDiffusersConfig */ - StableDiffusion2ModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - 
description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusionXLModelCheckpointConfig */ - StableDiffusionXLModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusionXLModelDiffusersConfig */ - StableDiffusionXLModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; /** * Step Param Easing * @description Experimental per-step parameter easing for denoising steps @@ -10079,6 +9947,7 @@ export type components = { }; /** * SubModelType + * @description Submodel type. * @enum {string} */ SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; @@ -10216,37 +10085,13 @@ export type components = { */ type: "t2i_adapter"; }; - /** T2IAdapterModelDiffusersConfig */ - T2IAdapterModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + /** T2IAdapterModelField */ + T2IAdapterModelField: { /** - * Model Type - * @default t2i_adapter - * @constant + * Key + * @description Model record key for the T2I-Adapter model */ - model_type: "t2i_adapter"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - }; - /** T2IAdapterModelField */ - T2IAdapterModelField: { - /** - * Model Name - * @description Name of the T2I-Adapter model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** T2IAdapterOutput */ T2IAdapterOutput: { @@ -10267,11 +10112,18 @@ export type components = { * @description Model config for T2I. 
*/ T2IConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default t2i_adapter @@ -10299,13 +10151,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** TBLR */ TBLR: { @@ -10323,11 +10183,18 @@ export type components = { * @description Model config for textual inversion embeddings. */ TextualInversionConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default embedding @@ -10355,32 +10222,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; - }; - /** TextualInversionModelConfig */ - TextualInversionModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; /** - * Model Type - * @default embedding - * @constant + * Last Modified + * @description timestamp for modification time */ - model_type: "embedding"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** Model Format */ - model_format: null; - error?: components["schemas"]["ModelError"] | null; + last_modified?: number | null; }; /** Tile */ Tile: { @@ -10546,7 +10402,7 @@ export type components = { }; /** * UNetOutput - * @description Base class for invocations that output a UNet field + * @description Base class for invocations that output a UNet field. */ UNetOutput: { /** @@ -10646,12 +10502,10 @@ export type components = { */ VAEModelField: { /** - * Model Name - * @description Name of the model + * Key + * @description Model's key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * VAEOutput @@ -10675,11 +10529,18 @@ export type components = { * @description Model config for standalone VAE models. 
*/ VaeCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default vae @@ -10708,24 +10569,39 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** * VaeDiffusersConfig * @description Model config for standalone VAE models (diffusers version). */ VaeDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default vae @@ -10754,13 +10630,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** VaeField */ VaeField: { @@ -10806,29 +10690,6 @@ export type components = { */ type: "vae_loader"; }; - /** VaeModelConfig */ - VaeModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default vae - * @constant - */ - model_type: "vae"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - model_format: components["schemas"]["VaeModelFormat"]; - error?: components["schemas"]["ModelError"] | null; - }; - /** - * VaeModelFormat - * @enum {string} - */ - VaeModelFormat: "checkpoint" | "diffusers"; /** ValidationError */ ValidationError: { /** Location */ @@ -11085,63 +10946,6 @@ export type components = { */ type: "zoe_depth_image_processor"; }; - /** - * ModelsList - * @description Return list of configs. 
- */ - invokeai__app__api__routers__model_records__ModelsList: { - /** Models */ - models: ((components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"])[]; - }; - /** ModelsList */ - invokeai__app__api__routers__models__ModelsList: { - /** Models */ - models: (components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[]; - }; - /** - * BaseModelType - * @enum {string} - */ - invokeai__backend__model_management__models__base__BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; - /** - * ModelType - * @enum {string} - */ - invokeai__backend__model_management__models__base__ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; - /** - * ModelVariantType - * @enum {string} - */ - invokeai__backend__model_management__models__base__ModelVariantType: "normal" | "inpaint" | "depth"; - /** - * SchedulerPredictionType - * @enum {string} - */ - invokeai__backend__model_management__models__base__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; - /** - * BaseModelType - * @description Base model type. - * @enum {string} - */ - invokeai__backend__model_manager__config__BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; - /** - * ModelType - * @description Model type. - * @enum {string} - */ - invokeai__backend__model_manager__config__ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; - /** - * ModelVariantType - * @description Variant type. - * @enum {string} - */ - invokeai__backend__model_manager__config__ModelVariantType: "normal" | "inpaint" | "depth"; - /** - * SchedulerPredictionType - * @description Scheduler prediction type. - * @enum {string} - */ - invokeai__backend__model_manager__config__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; /** * Classification * @description The classification of an Invocation. 
@@ -11309,17 +11113,17 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * StableDiffusionXLModelFormat + * VaeModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + VaeModelFormat: "checkpoint" | "diffusers"; /** - * T2IAdapterModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * CLIPVisionModelFormat * @description An enumeration. @@ -11327,11 +11131,11 @@ export type components = { */ CLIPVisionModelFormat: "diffusers"; /** - * ControlNetModelFormat + * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. @@ -11339,23 +11143,35 @@ export type components = { */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - IPAdapterModelFormat: "invokeai"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusionOnnxModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + T2IAdapterModelFormat: "diffusers"; + /** + * LoRAModelFormat + * @description An enumeration. + * @enum {string} + */ + LoRAModelFormat: "lycoris" | "diffusers"; + /** + * IPAdapterModelFormat + * @description An enumeration. 
+ * @enum {string} + */ + IPAdapterModelFormat: "invokeai"; }; responses: never; parameters: never; @@ -11426,25 +11242,63 @@ export type operations = { }; }; /** - * List Models - * @description Gets a list of models + * List Model Records + * @description Get a list of models. */ - list_models: { + list_model_records: { parameters: { query?: { /** @description Base models to include */ - base_models?: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"][] | null; + base_models?: components["schemas"]["BaseModelType"][] | null; /** @description The type of model to get */ - model_type?: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"] | null; + model_type?: components["schemas"]["ModelType"] | null; + /** @description Exact match on the name of the model */ + model_name?: string | null; + /** @description Exact match on the format of the model (e.g. 'diffusers') */ + model_format?: components["schemas"]["ModelFormat"] | null; }; }; responses: { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["invokeai__app__api__routers__models__ModelsList"]; + "application/json": components["schemas"]["ModelsList"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * Get Model Record + * @description Get a model record + */ + get_model_record: { + parameters: { + path: { + /** @description Key of the model record to fetch. */ + key: string; + }; + }; + responses: { + /** @description The model configuration was retrieved successfully */ + 200: { + content: { + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description The model could not be found */ + 404: { + content: never; + }; /** @description Validation Error */ 422: { content: { @@ -11454,18 +11308,17 @@ export type operations = { }; }; /** - * Delete Model - * @description Delete Model + * Del Model Record + * @description Delete model record from database. + * + * The configuration record will be removed. The corresponding weights files will be + * deleted as well if they reside within the InvokeAI "models" directory. */ - del_model: { + del_model_record: { parameters: { path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; + /** @description Unique key of model to remove from model registry. 
*/ + key: string; }; }; responses: { @@ -11486,30 +11339,38 @@ export type operations = { }; }; /** - * Update Model + * Update Model Record * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. */ - update_model: { + update_model_record: { parameters: { path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; + /** @description Unique key of model */ + key: string; }; }; requestBody: { content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; + /** + * @example { + * "path": "/path/to/model", + * "name": "model_name", + * "base": "sd-1", + * "type": "main", + * "format": "checkpoint", + * "config": "configs/stable-diffusion/v1-inference.yaml", + * "description": "Model description", + * "variant": "normal" + * } + */ + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; responses: { /** @description The model was updated successfully */ 200: { content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | 
components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; /** @description Bad request */ @@ -11533,387 +11394,32 @@ export type operations = { }; }; /** - * Import Model - * @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically + * List Model Summary + * @description Gets a page of model summary data. */ - import_model: { - requestBody: { - content: { - "application/json": components["schemas"]["Body_import_model"]; + list_model_summary: { + parameters: { + query?: { + /** @description The page to get */ + page?: number; + /** @description The number of models per page */ + per_page?: number; + /** @description The attribute to order by */ + order_by?: components["schemas"]["ModelRecordOrderBy"]; }; }; responses: { - /** @description The model imported successfully */ - 201: { + /** @description Successful Response */ + 200: { content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; + "application/json": components["schemas"]["PaginatedResults_ModelSummary_"]; }; }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to this path or repo_id */ - 409: { - content: never; - }; - /** @description Unrecognized file/folder format */ - 415: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - /** @description The model appeared to import successfully, but could not be found in the model manager */ - 424: { - content: never; - }; - }; - }; - /** - * Add Model - * @description Add a model using the configuration information appropriate for its type. 
Only local models can be added by path - */ - add_model: { - requestBody: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - responses: { - /** @description The model added successfully */ - 201: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to this path or repo_id */ - 409: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - /** @description The model appeared to add successfully, but could not be found in the model manager */ - 424: { - content: never; - }; - }; - }; - /** - * Convert Model - * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. 
- */ - convert_model: { - parameters: { - query?: { - /** @description Save the converted model to the designated directory */ - convert_dest_directory?: string | null; - }; - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; - }; - }; - responses: { - /** @description Model converted successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description Model not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** Search For Models */ - search_for_models: { - parameters: { - query: { - /** @description Directory path to search for models */ - search_path: string; - }; - }; - responses: { - /** @description Directory searched successfully */ - 200: { - content: { - "application/json": string[]; - }; - }; - /** @description Invalid directory path */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Ckpt Configs - * @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT. - */ - list_ckpt_configs: { - responses: { - /** @description paths retrieved successfully */ - 200: { - content: { - "application/json": string[]; - }; - }; - }; - }; - /** - * Sync To Config - * @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize - * in-memory data structures with disk data structures. 
- */ - sync_to_config: { - responses: { - /** @description synchronization successful */ - 201: { - content: { - "application/json": boolean; - }; - }; - }; - }; - /** - * Merge Models - * @description Convert a checkpoint model into a diffusers model - */ - merge_models: { - parameters: { - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - }; - }; - requestBody: { - content: { - "application/json": components["schemas"]["Body_merge_models"]; - }; - }; - responses: { - /** @description Model converted successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Incompatible models */ - 400: { - content: never; - }; - /** @description One or more models not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Model Records - * @description Get a list of models. - */ - list_model_records: { - parameters: { - query?: { - /** @description Base models to include */ - base_models?: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"][] | null; - /** @description The type of model to get */ - model_type?: components["schemas"]["invokeai__backend__model_manager__config__ModelType"] | null; - /** @description Exact match on the name of the model */ - model_name?: string | null; - /** @description Exact match on the format of the model (e.g. 'diffusers') */ - model_format?: components["schemas"]["ModelFormat"] | null; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["invokeai__app__api__routers__model_records__ModelsList"]; - }; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Get Model Record - * @description Get a model record - */ - get_model_record: { - parameters: { - path: { - /** @description Key of the model record to fetch. 
*/ - key: string; - }; - }; - responses: { - /** @description Success */ - 200: { - content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Del Model Record - * @description Delete model record from database. - * - * The configuration record will be removed. The corresponding weights files will be - * deleted as well if they reside within the InvokeAI "models" directory. - */ - del_model_record: { - parameters: { - path: { - /** @description Unique key of model to remove from model registry. */ - key: string; - }; - }; - responses: { - /** @description Model deleted successfully */ - 204: { - content: never; - }; - /** @description Model not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Update Model Record - * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. 
- */ - update_model_record: { - parameters: { - path: { - /** @description Unique key of model */ - key: string; - }; - }; - requestBody: { - content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; - }; - }; - responses: { - /** @description The model was updated successfully */ - 200: { - content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to the new name */ - 409: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Model Summary - * @description Gets a page of model summary data. 
- */ - list_model_summary: { - parameters: { - query?: { - /** @description The page to get */ - page?: number; - /** @description The number of models per page */ - per_page?: number; - /** @description The attribute to order by */ - order_by?: components["schemas"]["ModelRecordOrderBy"]; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["PaginatedResults_ModelSummary_"]; - }; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; }; }; }; @@ -11929,7 +11435,7 @@ export type operations = { }; }; responses: { - /** @description Success */ + /** @description The model metadata was retrieved successfully */ 200: { content: { "application/json": (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null; @@ -11980,7 +11486,7 @@ export type operations = { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["invokeai__app__api__routers__model_records__ModelsList"]; + "application/json": components["schemas"]["ModelsList"]; }; }; /** @description Validation Error */ @@ -11998,14 +11504,26 @@ export type operations = { add_model_record: { requestBody: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + /** + * @example { + * "path": "/path/to/model", + * "name": "model_name", + * "base": "sd-1", + * "type": "main", + * "format": "checkpoint", + * "config": "configs/stable-diffusion/v1-inference.yaml", + * "description": "Model description", + * "variant": "normal" + * } + */ + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; responses: { /** @description The model added successfully */ 201: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | 
components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
        };
      };
      /** @description There is already a model corresponding to this path or repo_id */
@@ -12025,22 +11543,81 @@ export type operations = {
    };
  };
  /**
-   * List Model Install Jobs
-   * @description Return list of model install jobs.
+   * Heuristic Import
+   * @description Install a model using a string identifier.
+   *
+   * `source` can be any of the following.
+   *
+   * 1. A path on the local filesystem ('C:\users\fred\model.safetensors')
+   * 2. A URL pointing to a single downloadable model file
+   * 3. A HuggingFace repo_id with any of the following formats:
+   *    - model/name
+   *    - model/name:fp16:vae
+   *    - model/name::vae -- use default precision
+   *    - model/name:fp16:path/to/model.safetensors
+   *    - model/name::path/to/model.safetensors
+   *
+   * `config` is an optional dict containing model configuration values that will override
+   * the ones that are probed automatically.
+   *
+   * `access_token` is an optional access token for use with URLs that require
+   * authentication.
+   *
+   * Models will be downloaded, probed, configured and installed in a
+   * series of background threads. The return object has a `status` attribute
+   * that can be used to monitor progress.
+   *
+   * See the documentation for `import_model` for more information on
+   * interpreting the job information returned by this route.
   */
-  list_model_install_jobs: {
+  heuristic_import_model: {
+    parameters: {
+      query: {
+        source: string;
+        access_token?: string | null;
+      };
+    };
+    requestBody?: {
+      content: {
+        /**
+         * @example {
+         *   "name": "modelT",
+         *   "description": "antique cars"
+         * }
+         */
+        "application/json": Record<string, unknown> | null;
+      };
+    };
    responses: {
-      /** @description Successful Response */
-      200: {
+      /** @description The model imported successfully */
+      201: {
        content: {
-          "application/json": components["schemas"]["ModelInstallJob"][];
+          "application/json": components["schemas"]["ModelInstallJob"];
        };
      };
+      /** @description There is already a model corresponding to this path or repo_id */
+      409: {
+        content: never;
+      };
+      /** @description Unrecognized file/folder format */
+      415: {
+        content: never;
+      };
+      /** @description Validation Error */
+      422: {
+        content: {
+          "application/json": components["schemas"]["HTTPValidationError"];
+        };
+      };
+      /** @description The model appeared to import successfully, but could not be found in the model manager */
+      424: {
+        content: never;
+      };
    };
  };
  /**
   * Import Model
-   * @description Add a model using its local path, repo_id, or remote URL.
+   * @description Install a model using its local path, repo_id, or remote URL.
   *
   * Models will be downloaded, probed, configured and installed in a
   * series of background threads. The return object has `status` attribute
@@ -12051,32 +11628,38 @@ export type operations = {
   * appropriate value:
   *
   * * To install a local path using LocalModelSource, pass a source of form:
-   * `{
+   * ```
+   * {
   * "type": "local",
   * "path": "/path/to/model",
   * "inplace": false
-   * }`
-   * The "inplace" flag, if true, will register the model in place in its
-   * current filesystem location. Otherwise, the model will be copied
-   * into the InvokeAI models directory.
+   * }
+   * ```
+   * The "inplace" flag, if true, will register the model in place in its
+   * current filesystem location. Otherwise, the model will be copied
+   * into the InvokeAI models directory.
   *
   * * To install a HuggingFace repo_id using HFModelSource, pass a source of form:
-   * `{
+   * ```
+   * {
   * "type": "hf",
   * "repo_id": "stabilityai/stable-diffusion-2.0",
   * "variant": "fp16",
   * "subfolder": "vae",
   * "access_token": "f5820a918aaf01"
-   * }`
-   * The `variant`, `subfolder` and `access_token` fields are optional.
+   * }
+   * ```
+   * The `variant`, `subfolder` and `access_token` fields are optional.
   *
   * * To install a remote model using an arbitrary URL, pass:
-   * `{
+   * ```
+   * {
   * "type": "url",
   * "url": "http://www.civitai.com/models/123456",
   * "access_token": "f5820a918aaf01"
-   * }`
-   * The `access_token` field is optonal
+   * }
+   * ```
+   * The `access_token` field is optional
   *
   * The model's configuration record will be probed and filled in
   * automatically. To override the default guesses, pass "metadata"
@@ -12085,19 +11668,19 @@ export type operations = {
   * Installation occurs in the background. Either use list_model_install_jobs()
   * to poll for completion, or listen on the event bus for the following events:
   *
-   * "model_install_running"
-   * "model_install_completed"
-   * "model_install_error"
+   * * "model_install_running"
+   * * "model_install_completed"
+   * * "model_install_error"
   *
   * On successful completion, the event's payload will contain the field "key"
   * containing the installed ID of the model. On an error, the event's payload
   * will contain the fields "error_type" and "error" describing the nature of the
   * error and its traceback, respectively.
   */
-  import_model_record: {
+  import_model: {
    requestBody: {
      content: {
-        "application/json": components["schemas"]["Body_import_model_record"];
+        "application/json": components["schemas"]["Body_import_model"];
      };
    };
    responses: {
@@ -12127,6 +11710,37 @@ export type operations = {
    };
  };
+  /**
+   * List Model Install Jobs
+   * @description Return the list of model install jobs.
+   *
+   * Install jobs have a numeric `id`, a `status`, and other fields that provide information on
+   * the nature of the job and its progress. The `status` is one of:
+   *
+   * * "waiting" -- Job is waiting in the queue to run
+   * * "downloading" -- Model file(s) are downloading
+   * * "running" -- Model has downloaded and the model probing and registration process is running
+   * * "completed" -- Installation completed successfully
+   * * "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
+   * * "cancelled" -- Job was cancelled before completion.
+   *
+   * Once completed, information about the model such as its size, base
+   * model, type, and metadata can be retrieved from the `config_out`
+   * field. For multi-file models such as diffusers, information on individual files
+   * can be retrieved from `download_parts`.
+   *
+   * See the example and schema below for more information.
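+   *
+   * As a rough illustration only (the field values here are invented for this
+   * sketch, not taken from a real install), a completed job entry might look like:
+   * ```
+   * {
+   *   "id": 1,
+   *   "status": "completed",
+   *   "source": {"type": "url", "url": "http://www.civitai.com/models/123456"},
+   *   "config_out": {"key": "f1f2e5a0", "type": "main", "base": "sd-1"},
+   *   "download_parts": []
+   * }
+   * ```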
+   */
+  list_model_install_jobs: {
+    responses: {
+      /** @description Successful Response */
+      200: {
+        content: {
+          "application/json": components["schemas"]["ModelInstallJob"][];
+        };
+      };
+    };
+  };
  /**
   * Prune Model Install Jobs
   * @description Prune all completed and errored jobs from the install job list.
@@ -12151,7 +11765,8 @@ export type operations = {
  };
  /**
   * Get Model Install Job
-   * @description Return model install job corresponding to the given source.
+   * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
+   * for information on the format of the return value.
   */
  get_model_install_job: {
    parameters: {
@@ -12234,16 +11849,59 @@ export type operations = {
      };
    };
  };
+  /**
+   * Convert Model
+   * @description Permanently convert a model into diffusers format, replacing the safetensors version.
+   * Note that during the conversion process the key and model hash will change.
+   * The return value is the model configuration for the converted model.
+   */
+  convert_model: {
+    parameters: {
+      path: {
+        /** @description Unique key of the safetensors main model to convert to diffusers format. */
+        key: string;
+      };
+    };
+    responses: {
+      /** @description Model converted successfully */
+      200: {
+        content: {
+          "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
+        };
+      };
+      /** @description Bad request */
+      400: {
+        content: never;
+      };
+      /** @description Model not found */
+      404: {
+        content: never;
+      };
+      /** @description There is already a model registered at this location */
+      409: {
+        content: never;
+      };
+      /** @description Validation Error */
+      422: {
+        content: {
+          "application/json": components["schemas"]["HTTPValidationError"];
+        };
+      };
+    };
+  };
  /**
   * Merge
-   * @description Merge diffusers models.
-   *
-   * keys: List of 2-3 model keys to merge together. All models must use the same base type.
-   * merged_model_name: Name for the merged model [Concat model names]
-   * alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
-   * force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
-   * interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
-   * merge_dest_directory: Specify a directory to store the merged model in [models directory]
+   * @description Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
+   * ```
+   * Argument              Description [default]
+   * --------              ----------------------
+   * keys                  List of 2-3 model keys to merge together. All models must use the same base type.
+   * merged_model_name     Name for the merged model [Concat model names]
+   * alpha                 Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
+   * force                 If true, force the merge even if the models were generated by different versions of the diffusers library [False]
+   * interp                Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
+   * merge_dest_directory  Specify a directory to store the merged model in [models directory]
+   * ```
   */
  merge: {
    requestBody: {
      content: {
        "application/json": components["schemas"]["Body_merge_models"];
      };
    };
    responses: {
-      /** @description Successful Response */
+      /** @description Model converted successfully */
      200: {
        content: {
-          "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
+          "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
        };
      };
+      /** @description Bad request */
+      400: {
+        content: never;
+      };
+      /** @description Model not found */
+      404: {
+        content: never;
+      };
+      /** @description There is already a model registered at this location */
+      409: {
+        content: never;
+      };
      /** @description Validation Error */
      422: {
        content: {
From 2b1ff8d196d495c307ed889c3db0bd8654cec181 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 15 Feb 2024 22:16:11 +1100
Subject: [PATCH 098/100] tests(ui): enable vitest type testing

This is useful for the zod schemas and types we have created to match the
backend.
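As a rough sketch of the kind of test this enables (the schema and server type
below are invented stand-ins, not our real definitions), a vitest type test
looks like:

```ts
// example.test-d.ts -- picked up by vitest's typecheck runner
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
import { z } from 'zod';

// A hand-written zod schema and the server-generated type it is meant to mirror.
const zImageField = z.object({ image_name: z.string() });
type ServerImageField = { image_name: string };

describe('schema parity', () => {
  // Compilation fails here if the zod-inferred type drifts from the server type.
  test('ImageField', () => assert<Equals<z.infer<typeof zImageField>, ServerImageField>>());
});
```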
--- invokeai/frontend/web/.gitignore | 3 +++ invokeai/frontend/web/package.json | 1 + invokeai/frontend/web/pnpm-lock.yaml | 7 +++++++ invokeai/frontend/web/vite.config.mts | 5 ++++- 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/.gitignore b/invokeai/frontend/web/.gitignore index 8e7ebc76a1f..3e8a372bc77 100644 --- a/invokeai/frontend/web/.gitignore +++ b/invokeai/frontend/web/.gitignore @@ -41,3 +41,6 @@ stats.html # Yalc .yalc yalc.lock + +# vitest +tsconfig.vitest-temp.json \ No newline at end of file diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index b2838e538ce..cea13350d26 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -154,6 +154,7 @@ "rollup-plugin-visualizer": "^5.12.0", "storybook": "^7.6.10", "ts-toolbelt": "^9.6.0", + "tsafe": "^1.6.6", "typescript": "^5.3.3", "vite": "^5.0.12", "vite-plugin-css-injected-by-js": "^3.3.1", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index f3bf68cf1da..0ec2e47a0cd 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -300,6 +300,9 @@ devDependencies: ts-toolbelt: specifier: ^9.6.0 version: 9.6.0 + tsafe: + specifier: ^1.6.6 + version: 1.6.6 typescript: specifier: ^5.3.3 version: 5.3.3 @@ -13505,6 +13508,10 @@ packages: resolution: {integrity: sha512-nsZd8ZeNUzukXPlJmTBwUAuABDe/9qtVDelJeT/qW0ow3ZS3BsQJtNkan1802aM9Uf68/Y8ljw86Hu0h5IUW3w==} dev: true + /tsafe@1.6.6: + resolution: {integrity: sha512-gzkapsdbMNwBnTIjgO758GujLCj031IgHK/PKr2mrmkCSJMhSOR5FeOuSxKLMUoYc0vAA4RGEYYbjt/v6afD3g==} + dev: true + /tsconfck@3.0.1(typescript@5.3.3): resolution: {integrity: sha512-7ppiBlF3UEddCLeI1JRx5m2Ryq+xk4JrZuq4EuYXykipebaq1dV0Fhgr1hb7CkmHt32QSgOZlcqVLEtHBG4/mg==} engines: {node: ^18 || >=20} diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index 325c6467dee..f4dbae71232 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -84,7 +84,10 @@ export default defineConfig(({ mode }) => { }, }, test: { - // + typecheck: { + enabled: true, + ignoreSourceErrors: true, + }, }, }; }); From 019898c7beb30459a541a013bdb518df9689ef17 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:16:55 +1100 Subject: [PATCH 099/100] tests(ui): add type tests --- .../src/features/nodes/types/common.test-d.ts | 69 +++++++++++++++++++ .../features/nodes/types/workflow.test-d.ts | 18 +++++ 2 files changed, 87 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/types/common.test-d.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts diff --git a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts new file mode 100644 index 00000000000..7f28e864a13 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts @@ -0,0 +1,69 @@ +import type { + BaseModel, + BoardField, + Classification, + CLIPField, + ColorField, + ControlField, + ControlNetModelField, + ImageField, + ImageOutput, + IPAdapterField, + IPAdapterModelField, + LoraInfo, + LoRAModelField, + MainModelField, + ModelInfo, + ModelType, + ProgressImage, + SchedulerField, + SDXLRefinerModelField, + SubModelType, + T2IAdapterField, + T2IAdapterModelField, + UNetField, + VAEField, +} from 'features/nodes/types/common'; 
+import type { S } from 'services/api/types';
+import type { Equals, Extends } from 'tsafe';
+import { assert } from 'tsafe';
+import { describe, test } from 'vitest';
+
+/**
+ * These types originate from the server and are recreated as zod schemas manually, for use at runtime.
+ * The tests ensure that the types are correctly recreated.
+ */
+
+describe('Common types', () => {
+  // Complex field types
+  test('ImageField', () => assert<Equals<ImageField, S['ImageField']>>());
+  test('BoardField', () => assert<Equals<BoardField, S['BoardField']>>());
+  test('ColorField', () => assert<Equals<ColorField, S['ColorField']>>());
+  test('SchedulerField', () => assert<Extends<SchedulerField, NonNullable<S['SchedulerInvocation']['scheduler']>>>());
+  test('UNetField', () => assert<Equals<UNetField, S['UNetField']>>());
+  test('CLIPField', () => assert<Equals<CLIPField, S['ClipField']>>());
+  test('MainModelField', () => assert<Equals<MainModelField, S['MainModelField']>>());
+  test('SDXLRefinerModelField', () => assert<Equals<SDXLRefinerModelField, S['MainModelField']>>());
+  test('VAEField', () => assert<Equals<VAEField, S['VaeField']>>());
+  test('ControlField', () => assert<Equals<ControlField, S['ControlField']>>());
+  // @ts-expect-error TODO(psyche): fix types
+  test('IPAdapterField', () => assert<Equals<IPAdapterField, S['IPAdapterField']>>());
+  test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>());
+  test('LoRAModelField', () => assert<Equals<LoRAModelField, S['LoRAModelField']>>());
+  test('ControlNetModelField', () => assert<Equals<ControlNetModelField, S['ControlNetModelField']>>());
+  test('IPAdapterModelField', () => assert<Equals<IPAdapterModelField, S['IPAdapterModelField']>>());
+  test('T2IAdapterModelField', () => assert<Equals<T2IAdapterModelField, S['T2IAdapterModelField']>>());
+
+  // Model component types
+  test('BaseModel', () => assert<Equals<BaseModel, S['BaseModelType']>>());
+  test('ModelType', () => assert<Equals<ModelType, S['ModelType']>>());
+  test('SubModelType', () => assert<Equals<SubModelType, S['SubModelType']>>());
+  test('ModelInfo', () => assert<Equals<ModelInfo, S['ModelInfo']>>());
+
+  // Misc types
+  test('LoraInfo', () => assert<Equals<LoraInfo, S['LoraInfo']>>());
+  // @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet
+  test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
+  test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>());
+  test('Classification', () => assert<Equals<Classification, S['Classification']>>());
+});
diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts
new file mode 100644
index 00000000000..7cb1ea230ce
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts
@@ -0,0 +1,18 @@
+import type { WorkflowCategory, WorkflowV3, XYPosition } from 'features/nodes/types/workflow';
+import type * as ReactFlow from 'reactflow';
+import type { S } from 'services/api/types';
+import type { Equals, Extends } from 'tsafe';
+import { assert } from 'tsafe';
+import { describe, test } from 'vitest';
+
+/**
+ * These types originate from the server and are recreated as zod schemas manually, for use at runtime.
+ * The tests ensure that the types are correctly recreated.
+ */
+
+describe('Workflow types', () => {
+  test('XYPosition', () => assert<Equals<XYPosition, ReactFlow.XYPosition>>());
+  test('WorkflowCategory', () => assert<Extends<WorkflowCategory, S['WorkflowCategory']>>());
+  // @ts-expect-error TODO(psyche): Need to revise server types!
+  test('WorkflowV3', () => assert<Equals<WorkflowV3, S['WorkflowWithoutID']>>());
+});
From 0c8112cf2811b3feabb2c246b94cce15466693a3 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 15 Feb 2024 22:17:16 +1100
Subject: [PATCH 100/100] fix(ui): update model types

---
 .../web/src/features/nodes/types/common.ts    | 31 ++++++++++---------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts
index b5244743799..ef579fce8cf 100644
--- a/invokeai/frontend/web/src/features/nodes/types/common.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/common.ts
@@ -52,27 +52,29 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
 // #region Model-related schemas
 export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
-export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']);
+export const zModelType = z.enum([
+  'main',
+  'vae',
+  'lora',
+  'controlnet',
+  'embedding',
+  'ip_adapter',
+  'clip_vision',
+  't2i_adapter',
+  'onnx', // TODO(psyche): Remove this when removed from backend
+]);
 export const zModelName = z.string().min(3);
 export const zModelIdentifier = z.object({
-  model_name: zModelName,
-  base_model: zBaseModel,
+  key: z.string().min(1),
 });
 export type BaseModel = z.infer<typeof zBaseModel>;
 export type ModelType = z.infer<typeof zModelType>;
 export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
-export const zMainModelField = z.object({
-  model_name: zModelName,
-  base_model: zBaseModel,
-  model_type: z.literal('main'),
-});
-export const zSDXLRefinerModelField = z.object({
-  model_name: z.string().min(1),
-  base_model: z.literal('sdxl-refiner'),
-  model_type: z.literal('main'),
-});
+export const zMainModelField = zModelIdentifier;
 export type MainModelField = z.infer<typeof zMainModelField>;
+
+export const zSDXLRefinerModelField = zModelIdentifier;
 export type SDXLRefinerModelField = z.infer<typeof zSDXLRefinerModelField>;

 export const zSubModelType = z.enum([
@@ -92,8 +94,7 @@ export type SubModelType = z.infer<typeof zSubModelType>;
 export const zVAEModelField = zModelIdentifier;

 export const zModelInfo = zModelIdentifier.extend({
-  model_type: zModelType,
-  submodel: zSubModelType.optional(),
+  submodel_type: zSubModelType.nullish(),
 });
 export type ModelInfo = z.infer<typeof zModelInfo>;
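For a quick sanity check of the new key-based identifier, here is a standalone
sketch (the schema is re-declared locally and the key value is invented):

```ts
import { z } from 'zod';

// Mirrors the key-based zModelIdentifier introduced in the hunk above.
const zModelIdentifier = z.object({ key: z.string().min(1) });

// Models are now referenced by an opaque record key...
console.log(zModelIdentifier.safeParse({ key: 'f1f2e5a0' }).success); // true

// ...while the old { model_name, base_model } shape no longer validates.
console.log(zModelIdentifier.safeParse({ model_name: 'sd15', base_model: 'sd-1' }).success); // false
```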